Unverified Commit 14203432 authored by ishandhanani's avatar ishandhanani Committed by GitHub
Browse files

fix(compile_utils, ep_moe): update environment variable and dtype check (#12034)

parent d7f0d88f
......@@ -36,7 +36,7 @@ SGLang supports various environment variables that can be used to configure its
| `SGLANG_JIT_DEEPGEMM_PRECOMPILE` | Enable precompilation of DeepGEMM kernels | `"true"` |
| `SGLANG_JIT_DEEPGEMM_COMPILE_WORKERS` | Number of workers for parallel DeepGEMM kernel compilation | `4` |
| `SGL_IN_DEEPGEMM_PRECOMPILE_STAGE` | Indicator flag used during the DeepGEMM precompile script | `"false"` |
| `SGL_DG_CACHE_DIR` | Directory for caching compiled DeepGEMM kernels | `~/.cache/deep_gemm` |
| `SGLANG_DG_CACHE_DIR` | Directory for caching compiled DeepGEMM kernels | `~/.cache/deep_gemm` |
| `SGL_DG_USE_NVRTC` | Use NVRTC (instead of Triton) for JIT compilation (Experimental) | `"0"` |
| `SGL_USE_DEEPGEMM_BMM` | Use DeepGEMM for Batched Matrix Multiplication (BMM) operations | `"false"` |
......
......@@ -26,7 +26,7 @@ _IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "fal
# Force redirect deep_gemm cache_dir
os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
"SGL_DG_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "deep_gemm")
"SGLANG_DG_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "deep_gemm")
)
# Refer to https://github.com/deepseek-ai/DeepGEMM/commit/d75b218b7b8f4a5dd5406ac87905039ead3ae42f
......
......@@ -440,9 +440,10 @@ class DeepEPMoE(FusedMoE):
hidden_states, hidden_states_scale, _, _, masked_m, expected_m = dispatch_output
assert self.quant_method is not None
assert self.moe_runner_config.activation == "silu"
assert (
hidden_states_scale.dtype == torch.float32
), f"hidden_states_scale.dtype: {hidden_states_scale.dtype}"
assert hidden_states_scale.dtype == torch.float32 or (
deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
and hidden_states_scale.dtype == torch.int32
), f"hidden_states_scale.dtype: {hidden_states_scale.dtype}, DEEPGEMM_SCALE_UE8M0: {deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0}"
# GroupGemm-0
num_groups, m, k = hidden_states.size()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment