Unverified Commit c2942907 authored by JieXin Liang, committed by GitHub

[feature] enable pre compile jit deep_gemm (#5580)

parent e69a2190
"""
Compile DeepGEMM kernels for a model with the specified server arguments.
This script launches a server for capturing DeepGEMM calls and then compiles the kernels.
It accepts server arguments (the same as launch_server.py).
Usage:
python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code
"""
import argparse
import dataclasses
import multiprocessing
import os
import time
import requests
from sglang.srt.entrypoints.http_server import launch_server
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import kill_process_tree
from sglang.srt.warmup import warmup
multiprocessing.set_start_method("spawn", force=True)
# Reduce warnings by marking that we are in the DeepGEMM pre-compile stage
os.environ["SGL_IN_DEEP_GEMM_PRE_COMPILE_STAGE"] = "1"
@dataclasses.dataclass
class CompileArgs:
timeout: int = 3600
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument("--timeout", type=int, default=CompileArgs.timeout)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
# use the default value's type to cast the args into correct types.
attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
return cls(
**{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
)
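# Registered as a named warmup: refine_server_args() below sets server_args.warmups to
# "compile-deep-gemm", so the server runs this coroutine during startup and the single
# dummy request it sends triggers the DeepGEMM JIT compilation path.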
@warmup("compile-deep-gemm")
async def warm_up_compile(tokenizer_manager: TokenizerManager):
print("\nGenerate warm up request for compiling DeepGEMM...\n")
generate_req_input = GenerateReqInput(
input_ids=[0, 1, 2, 3],
sampling_params={
"temperature": 0.0,
"max_new_tokens": 8,
"ignore_eos": True,
},
)
await tokenizer_manager.generate_request(generate_req_input, None).__anext__()
def launch_server_internal(server_args):
try:
launch_server(server_args)
except Exception as e:
raise e
finally:
kill_process_tree(os.getpid(), include_parent=False)
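# Start the server in a subprocess and poll /v1/models until it answers with 200,
# which is expected to happen only after the warmup request (and therefore the kernel
# compilation) has completed; a TimeoutError is raised if the timeout expires first.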
def launch_server_process_and_send_one_request(
server_args: ServerArgs, compile_args: CompileArgs
):
proc = multiprocessing.Process(target=launch_server_internal, args=(server_args,))
proc.start()
base_url = f"http://{server_args.host}:{server_args.port}"
timeout = compile_args.timeout
start_time = time.time()
while time.time() - start_time < timeout:
try:
headers = {
"Content-Type": "application/json; charset=utf-8",
}
response = requests.get(f"{base_url}/v1/models", headers=headers)
if response.status_code == 200:
return proc
except requests.RequestException:
pass
time.sleep(10)
raise TimeoutError(
"DeepGEMM Kernels compilation timeout."
"\n\nFeel free and please restart the command."
)
def refine_server_args(server_args: ServerArgs, compile_args: CompileArgs):
# Disable CUDA graph and torch compile to save time
server_args.disable_cuda_graph = True
server_args.enable_torch_compile = False
print(f"Disable CUDA Graph and Torch Compile to save time...")
# Set watchdog timeout to compile_args.timeout because compilation will take a long time
server_args.watchdog_timeout = compile_args.timeout
server_args.warmups = "compile-deep-gemm"
def run_compile(server_args: ServerArgs, compile_args: CompileArgs):
print(
"Begin DeepGEMM Kernels compilation...\n"
"It may take a long time and timeout maybe raised "
"while the compilation is still in progress.\n"
"Just feel free to restart the command "
"until the compilation is fully finished.\n"
)
proc = launch_server_process_and_send_one_request(server_args, compile_args)
kill_process_tree(proc.pid)
print("\nDeepGEMM Kernels compilation finished successfully.")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
CompileArgs.add_cli_args(parser)
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)
compile_args = CompileArgs.from_cli_args(args)
refine_server_args(server_args, compile_args)
run_compile(server_args, compile_args)
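# Typical workflow, as described by the warnings in the DeepGEMM module below: run this
# script once with the same arguments as `sglang.launch_server`; the compiled kernels are
# cached on disk (see DG_CACHE_DIR) and reused by later server launches.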
import logging
import os
from contextlib import contextmanager
from dataclasses import dataclass
from enum import IntEnum, auto
from typing import Callable, Dict, List, Optional, Tuple
import torch
from tqdm.contrib.concurrent import thread_map
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_bool_env_var, get_device_sm, get_int_env_var, is_cuda
_ENABLE_JIT_DEEPGEMM = False
if is_cuda():
import deep_gemm
from deep_gemm import get_num_sms
from deep_gemm.jit_kernels.gemm import get_best_configs
from deep_gemm.jit_kernels.gemm import includes as deep_gemm_includes
from deep_gemm.jit_kernels.gemm import template as deep_gemm_gemm_template
from deep_gemm.jit_kernels.m_grouped_gemm import (
template as deep_gemm_grouped_gemm_template,
)
from deep_gemm.jit_kernels.tuner import jit_tuner
sm_version = get_device_sm()
if sm_version == 90:
if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="false"):
_ENABLE_JIT_DEEPGEMM = True
logger = logging.getLogger(__name__)
_BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
_ENABLE_JIT_DEEPGEMM_PRECOMPILE = get_bool_env_var(
"SGL_JIT_DEEPGEMM_PRECOMPILE", "true"
)
_DO_COMPILE = get_bool_env_var("SGL_IS_FIRST_RANK_ON_NODE", "true")
_COMPILE_WORKERS = get_int_env_var("SGL_JIT_DEEPGEMM_COMPILE_WORKERS", 4)
_IN_PRE_COMPILE_STAGE = get_bool_env_var("SGL_IN_DEEP_GEMM_PRE_COMPILE_STAGE", "false")
# Force-redirect the deep_gemm cache dir
os.environ["DG_CACHE_DIR"] = os.getenv(
"SGL_DG_CACHE_DIR", os.path.expanduser("~") + "/.cache/deep_gemm"
)
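# Environment variables read above:
#   SGL_ENABLE_JIT_DEEPGEMM            - enable the JIT DeepGEMM path (SM90 only)
#   SGL_JIT_DEEPGEMM_PRECOMPILE        - pre-compile kernels for all candidate Ms
#   SGL_IS_FIRST_RANK_ON_NODE          - default for whether this rank compiles
#   SGL_JIT_DEEPGEMM_COMPILE_WORKERS   - number of parallel compile threads
#   SGL_IN_DEEP_GEMM_PRE_COMPILE_STAGE - set by sglang.compile_deep_gemm to reduce warnings
#   SGL_DG_CACHE_DIR                   - overrides the deep_gemm kernel cache directory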
def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
global _BUILTIN_M_LIST
global _DO_COMPILE
# Generate m_max
m_max = 1024 * 16
if server_args.chunked_prefill_size < 1:
m_max = 1024 * 64
elif server_args.chunked_prefill_size > 8192:
m_max = server_args.chunked_prefill_size * 2
m_max = min(1024 * 128, m_max)
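# Example: chunked_prefill_size=16384 gives m_max = min(131072, 2 * 16384) = 32768,
# so kernels are pre-compiled for every M in [1, 32768].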
_BUILTIN_M_LIST = list(range(1, m_max + 1))
# Check whether this is the first rank on the node
_DO_COMPILE = server_args.base_gpu_id == gpu_id
class DeepGemmKernelType(IntEnum):
GROUPED_GEMM_NT_F8F8BF16_MASKED = auto()
GROUPED_GEMM_NT_F8F8BF16_CONTIG = auto()
GEMM_NT_F8F8BF16 = auto()
@dataclass
class DeepGemmKernelHelper:
name: str
compile_func: Callable[
[
int,
int,
int,
Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
],
None,
]
configure_func: Callable[
[int, int, int, int, int],
Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
]
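# NOTE: judging from how the _compile_* helpers below unpack it, the config tuple
# produced by configure_func is
# (<unused>, block_m, block_n, num_stages,
#  (num_tma_multicast, is_tma_multicast_on_a), smem_config),
# where smem_config[1] is the swizzle-D mode and smem_config[2] the BLOCK_N padding.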
_INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dict()
def _compile_warning_1():
if not _IN_PRE_COMPILE_STAGE:
logger.warning(
"Entering DeepGEMM JIT Pre-Complie session. "
"And it may takes a long time(Typically 10-20 mins) "
"if you have not run `sglang.compile_deep_gemm`. "
"Recommand to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
" for pre-compilation to reduce the overhead if you have not run it before. "
"For example: "
"`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
)
def _compile_warning_2():
logger.warning(
"Entering DeepGEMM JIT Single Kernel Complie session. "
"And it will makes inference throughput becomes flaky. "
"Please run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
" for pre-compilation to solve this issue. "
"For example: "
"`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
)
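# Each _compile_*_one helper below calls jit_tuner.compile_and_tune with an empty tuning
# space (space=()) and no runtime args (args=[]), so it only builds the kernel for one
# concrete config to populate the JIT cache rather than benchmarking at runtime.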
def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
n: int,
k: int,
num_groups: int,
config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
) -> None:
# Auto-tuning with compilation
global deep_gemm_includes, deep_gemm_grouped_gemm_template
_, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
_ = jit_tuner.compile_and_tune(
name="m_grouped_gemm_fp8_fp8_bf16_nt",
keys={
"N": n,
"K": k,
"BLOCK_M": block_m,
"BLOCK_N": block_n,
"SWIZZLE_D_MODE": smem_config[1],
"BLOCK_N_PADDING": smem_config[2],
"NUM_GROUPS": num_groups,
"NUM_STAGES": num_stages,
"NUM_TMA_MULTICAST": tma_multicast_config[0],
"IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
"GEMM_TYPE": "GroupedMasked",
},
space=(),
includes=deep_gemm_includes,
arg_defs=(
("lhs", torch.float8_e4m3fn),
("lhs_scales", torch.float),
("rhs", torch.float8_e4m3fn),
("rhs_scales", torch.float),
("out", torch.bfloat16),
("grouped_layout", torch.int32),
("m", int),
("stream", torch.cuda.Stream),
("num_sms", int),
("smem_size", int),
),
template=deep_gemm_grouped_gemm_template,
args=[],
)
def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
n: int,
k: int,
num_groups: int,
config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
) -> None:
global deep_gemm_includes, deep_gemm_grouped_gemm_template
_, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
_ = jit_tuner.compile_and_tune(
name="m_grouped_gemm_fp8_fp8_bf16_nt",
keys={
"N": n,
"K": k,
"BLOCK_M": block_m,
"BLOCK_N": block_n,
"SWIZZLE_D_MODE": smem_config[1],
"BLOCK_N_PADDING": smem_config[2],
"NUM_GROUPS": num_groups,
"NUM_STAGES": num_stages,
"NUM_TMA_MULTICAST": tma_multicast_config[0],
"IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
"GEMM_TYPE": "GroupedContiguous",
},
space=(),
includes=deep_gemm_includes,
arg_defs=(
("lhs", torch.float8_e4m3fn),
("lhs_scales", torch.float),
("rhs", torch.float8_e4m3fn),
("rhs_scales", torch.float),
("out", torch.bfloat16),
("grouped_layout", torch.int32),
("m", int),
("num_groups", int),
("stream", torch.cuda.Stream),
("num_sms", int),
("smem_size", int),
),
template=deep_gemm_grouped_gemm_template,
args=[],
)
def _compile_gemm_nt_f8f8bf16_one(
n: int,
k: int,
_: int, # _ is a dummy parameter to align with other interfaces
config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
) -> None:
global deep_gemm_includes, deep_gemm_gemm_template
_, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
_ = jit_tuner.compile_and_tune(
name="gemm_fp8_fp8_bf16_nt",
keys={
"N": n,
"K": k,
"BLOCK_M": block_m,
"BLOCK_N": block_n,
"SWIZZLE_D_MODE": smem_config[1],
"BLOCK_N_PADDING": smem_config[2],
"NUM_STAGES": num_stages,
"NUM_TMA_MULTICAST": tma_multicast_config[0],
"IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
},
space=(),
includes=deep_gemm_includes,
arg_defs=(
("lhs", torch.float8_e4m3fn),
("lhs_scales", torch.float),
("rhs", torch.float8_e4m3fn),
("rhs_scales", torch.float),
("out", torch.bfloat16),
("m", int),
("stream", torch.cuda.Stream),
("num_sms", int),
("smem_size", int),
),
template=deep_gemm_gemm_template,
args=[],
)
_KERNEL_HELPER_DICT: Dict[DeepGemmKernelType, DeepGemmKernelHelper] = {
DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED: DeepGemmKernelHelper(
name="m_grouped_gemm_fp8_fp8_bf16_nt_masked",
compile_func=_compile_grouped_gemm_nt_f8f8bf16_masked_one,
configure_func=lambda m, n, k, num_groups, num_sms: get_best_configs(
m, n, k, num_groups, num_sms, is_grouped_masked=True
),
),
DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG: DeepGemmKernelHelper(
name="m_grouped_gemm_fp8_fp8_bf16_nt_contiguous",
compile_func=_compile_grouped_gemm_nt_f8f8bf16_contig_one,
configure_func=lambda m, n, k, _, num_sms: get_best_configs(
m, n, k, 1, num_sms, is_grouped_contiguous=True
),
),
DeepGemmKernelType.GEMM_NT_F8F8BF16: DeepGemmKernelHelper(
name="gemm_fp8_fp8_bf16_nt",
compile_func=_compile_gemm_nt_f8f8bf16_one,
configure_func=lambda m, n, k, _, num_sms: get_best_configs(
m, n, k, 1, num_sms
),
),
}
def _maybe_compile_deep_gemm_one_type_all(
kernel_type: DeepGemmKernelType,
n: int,
k: int,
num_groups: int,
m_list: Optional[List[int]] = None,
) -> None:
global _INITIALIZATION_DICT
global _BUILTIN_M_LIST
query_key = (kernel_type, n, k, num_groups)
if (
_ENABLE_JIT_DEEPGEMM_PRECOMPILE
and _DO_COMPILE
and _INITIALIZATION_DICT.get(query_key) is None
):
_INITIALIZATION_DICT[query_key] = True
kernel_helper = _KERNEL_HELPER_DICT[kernel_type]
_compile_warning_1()
logger.info(
f"Try DeepGEMM JIT Compiling for "
f"<{kernel_helper.name}> N={n}, K={k}, num_groups={num_groups} with all Ms."
f"{' It only takes a litte time(Typically 1 sec) if you have run `sglang.compile_deep_gemm`. ' if not _IN_PRE_COMPILE_STAGE else ''}"
)
# NOTE(alcanderian): get_num_sms should be changed when 2-batch-overlap is introduced
num_sms = get_num_sms()
collected_configs = set()
for m in m_list if m_list is not None else _BUILTIN_M_LIST:
# Put configs into a set to deduplicate them and reduce the number of cases to compile
collected_configs.add(
kernel_helper.configure_func(m, n, k, num_groups, num_sms)
)
compile_func = lambda config: kernel_helper.compile_func(
n, k, num_groups, config
)
thread_map(compile_func, collected_configs, max_workers=_COMPILE_WORKERS)
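# The public wrappers below first trigger, at most once per (kernel_type, n, k,
# num_groups), pre-compilation across all candidate M values, then dispatch to the
# corresponding deep_gemm kernel while logging any JIT build that still happens.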
def grouped_gemm_nt_f8f8bf16_masked(
lhs: Tuple[torch.Tensor, torch.Tensor],
rhs: Tuple[torch.Tensor, torch.Tensor],
out: torch.Tensor,
masked_m: torch.Tensor,
expected_m: int,
):
num_groups, _, k = lhs[0].shape
_, n, _ = rhs[0].shape
kernel_type = DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED
_maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
with _log_jit_build(expected_m, n, k, kernel_type):
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
lhs, rhs, out, masked_m, expected_m
)
def grouped_gemm_nt_f8f8bf16_contig(
lhs: Tuple[torch.Tensor, torch.Tensor],
rhs: Tuple[torch.Tensor, torch.Tensor],
out: torch.Tensor,
m_indices: torch.Tensor,
):
m, k = lhs[0].shape
num_groups, n, _ = rhs[0].shape
kernel_type = DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG
_maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
with _log_jit_build(m, n, k, kernel_type):
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs, rhs, out, m_indices)
def gemm_nt_f8f8bf16(
lhs: Tuple[torch.Tensor, torch.Tensor],
rhs: Tuple[torch.Tensor, torch.Tensor],
out: torch.Tensor,
):
m, k = lhs[0].shape
n, _ = rhs[0].shape
kernel_type = DeepGemmKernelType.GEMM_NT_F8F8BF16
_maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, 1)
with _log_jit_build(m, n, k, kernel_type):
deep_gemm.gemm_fp8_fp8_bf16_nt(lhs, rhs, out)
@contextmanager
def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
if _IN_PRE_COMPILE_STAGE:
yield
return
from deep_gemm.jit.runtime import RuntimeCache
origin_func = RuntimeCache.__getitem__
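# Temporarily patch RuntimeCache.__getitem__: a cache miss (ret is None) means deep_gemm
# is about to JIT-compile this kernel, so warn the user before the potentially slow build
# starts; the original method is restored after the yield.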
def __patched_func(self, *args, **kwargs):
ret = origin_func(self, *args, **kwargs)
if ret is None:
kernel_helper = _KERNEL_HELPER_DICT[kernel_type]
_compile_warning_2()
logger.warning(
f"DeepGEMM JIT Compiling for <{kernel_helper.name}> M={M}, N={N}, K={K}. Please wait."
)
return ret
RuntimeCache.__getitem__ = __patched_func
yield
RuntimeCache.__getitem__ = origin_func
@@ -16,19 +16,17 @@ import functools
import json
import logging
import os
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple
import torch
import triton
import triton.language as tl
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
from sglang.srt.utils import (
direct_register_custom_op,
get_bool_env_var,
get_device_core_count,
get_device_name,
get_device_sm,
is_cuda,
is_hip,
supports_custom_op,
@@ -43,22 +41,16 @@ else:
fp8_max = torch.finfo(_fp8_type).max
fp8_min = -fp8_max
_enable_jit_deepgemm = False
_enable_jit_deepgemm_bmm = False
if _is_cuda:
import deep_gemm
from sgl_kernel import (
sgl_per_tensor_quant_fp8,
sgl_per_token_group_quant_fp8,
sgl_per_token_quant_fp8,
)
sm_version = get_device_sm()
if sm_version == 90:
if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="false"):
_enable_jit_deepgemm = True
if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM_BMM", default="false"):
_enable_jit_deepgemm_bmm = True
from sglang.srt.layers.quantization.deep_gemm import (
gemm_nt_f8f8bf16 as deep_gemm_gemm_nt_f8f8bf16,
)
logger = logging.getLogger(__name__)
@@ -71,10 +63,7 @@ if supports_custom_op():
Bs: torch.Tensor,
C: torch.Tensor,
) -> None:
M, K = A.shape
N, _ = B.shape
with _log_jit_build(M, N, K):
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
def deep_gemm_fp8_fp8_bf16_nt_fake(
A: torch.Tensor,
@@ -715,25 +704,6 @@ def get_w8a8_block_fp8_configs(
return None
@contextmanager
def _log_jit_build(M: int, N: int, K: int):
from deep_gemm.jit.runtime import RuntimeCache
origin_func = RuntimeCache.__getitem__
def __patched_func(self, *args, **kwargs):
ret = origin_func(self, *args, **kwargs)
if ret is None:
logger.warning(
f"DeepGEMM JIT code generation <gemm_fp8_fp8_bf16_nt>: M={M}, N={N}, K={K}. Please wait."
)
return ret
RuntimeCache.__getitem__ = __patched_func
yield
RuntimeCache.__getitem__ = origin_func
def w8a8_block_fp8_matmul(
A: torch.Tensor,
B: torch.Tensor,
@@ -804,12 +774,11 @@ def w8a8_block_fp8_matmul(
)
# DeepGEMM only supports bf16
if C.dtype == torch.bfloat16 and _enable_jit_deepgemm:
if C.dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM:
if supports_custom_op():
torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
else:
with _log_jit_build(M, N, K):
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
else:
kernel = (
_w8a8_block_fp8_matmul_unrolledx4
@@ -12,8 +12,8 @@ try:
except ImportError:
VLLM_AVAILABLE = False
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
from sglang.srt.layers.quantization.fp8_kernel import (
_enable_jit_deepgemm,
per_token_group_quant_fp8,
scaled_fp8_quant,
sglang_per_token_quant_fp8,
@@ -143,7 +143,7 @@ def apply_w8a8_block_fp8_linear(
)
gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
else:
if _enable_jit_deepgemm:
if _ENABLE_JIT_DEEPGEMM:
q_input, x_scale = sglang_per_token_group_quant_fp8(
input_2d,
block_size[1],
@@ -42,6 +42,10 @@ from sglang.srt.layers.dp_attention import (
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer
from sglang.srt.layers.quantization.deep_gemm import (
_ENABLE_JIT_DEEPGEMM,
update_deep_gemm_config,
)
from sglang.srt.layers.sampler import Sampler
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
from sglang.srt.lora.lora_manager import LoRAManager
@@ -169,6 +173,10 @@ class ModelRunner:
# Get memory before model loading
min_per_gpu_memory = self.init_torch_distributed()
# Update the DeepGEMM configuration
if _ENABLE_JIT_DEEPGEMM:
update_deep_gemm_config(gpu_id, server_args)
# If it is a draft model, tp_group can be different.
self.initialize(min_per_gpu_memory)
@@ -57,8 +57,8 @@ from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
from sglang.srt.layers.quantization.fp8_kernel import (
_enable_jit_deepgemm_bmm,
per_tensor_quant_mla_deep_gemm_masked_fp8,
per_tensor_quant_mla_fp8,
)
@@ -86,8 +86,11 @@ _is_hip = is_hip()
_is_cuda = is_cuda()
if _is_cuda:
from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked
from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
from sglang.srt.layers.quantization.deep_gemm import (
grouped_gemm_nt_f8f8bf16_masked as deep_gemm_grouped_gemm_nt_f8f8bf16_masked,
)
else:
from vllm._custom_ops import awq_dequantize
@@ -702,7 +705,7 @@ class DeepseekV2AttentionMLA(nn.Module):
q_nope_out = q_nope.new_empty(
(self.num_local_heads, aligned_m, self.kv_lora_rank)
)
m_grouped_gemm_fp8_fp8_bf16_nt_masked(
deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
(q_nope_val, q_nope_scale),
(self.w_kc, self.w_scale_k),
q_nope_out,
@@ -751,7 +754,7 @@ class DeepseekV2AttentionMLA(nn.Module):
attn_bmm_output = attn_output.new_empty(
(self.num_local_heads, aligned_m, self.v_head_dim)
)
m_grouped_gemm_fp8_fp8_bf16_nt_masked(
deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
(attn_output_val, attn_output_scale),
(self.w_vc, self.w_scale_v),
attn_bmm_output,
@@ -1520,7 +1523,7 @@ class DeepseekV2ForCausalLM(nn.Module):
if (
_is_cuda
and _enable_jit_deepgemm_bmm
and _ENABLE_JIT_DEEPGEMM
and weight_block_size[0] == 128
and weight_block_size[1] == 128
and model_dtype == torch.bfloat16
@@ -98,6 +98,16 @@ def get_bool_env_var(name: str, default: str = "false") -> bool:
return value in truthy_values
def get_int_env_var(name: str, default: int = 0) -> int:
value = os.getenv(name)
if value is None or not value.strip():
return default
try:
return int(value)
except ValueError:
return default
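# Example: SGL_JIT_DEEPGEMM_COMPILE_WORKERS="8" -> 8; an unset, empty, or
# non-integer value falls back to the provided default.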
# https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip
def is_hip() -> bool:
return torch.version.hip is not None