Unverified Commit 14203432 authored by ishandhanani's avatar ishandhanani Committed by GitHub
Browse files

fix(compile_utils, ep_moe): update environment variable and dtype check (#12034)

parent d7f0d88f
...@@ -36,7 +36,7 @@ SGLang supports various environment variables that can be used to configure its ...@@ -36,7 +36,7 @@ SGLang supports various environment variables that can be used to configure its
| `SGLANG_JIT_DEEPGEMM_PRECOMPILE` | Enable precompilation of DeepGEMM kernels | `"true"` | | `SGLANG_JIT_DEEPGEMM_PRECOMPILE` | Enable precompilation of DeepGEMM kernels | `"true"` |
| `SGLANG_JIT_DEEPGEMM_COMPILE_WORKERS` | Number of workers for parallel DeepGEMM kernel compilation | `4` | | `SGLANG_JIT_DEEPGEMM_COMPILE_WORKERS` | Number of workers for parallel DeepGEMM kernel compilation | `4` |
| `SGL_IN_DEEPGEMM_PRECOMPILE_STAGE` | Indicator flag used during the DeepGEMM precompile script | `"false"` | | `SGL_IN_DEEPGEMM_PRECOMPILE_STAGE` | Indicator flag used during the DeepGEMM precompile script | `"false"` |
| `SGL_DG_CACHE_DIR` | Directory for caching compiled DeepGEMM kernels | `~/.cache/deep_gemm` | | `SGLANG_DG_CACHE_DIR` | Directory for caching compiled DeepGEMM kernels | `~/.cache/deep_gemm` |
| `SGL_DG_USE_NVRTC` | Use NVRTC (instead of Triton) for JIT compilation (Experimental) | `"0"` | | `SGL_DG_USE_NVRTC` | Use NVRTC (instead of Triton) for JIT compilation (Experimental) | `"0"` |
| `SGL_USE_DEEPGEMM_BMM` | Use DeepGEMM for Batched Matrix Multiplication (BMM) operations | `"false"` | | `SGL_USE_DEEPGEMM_BMM` | Use DeepGEMM for Batched Matrix Multiplication (BMM) operations | `"false"` |
......
...@@ -26,7 +26,7 @@ _IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "fal ...@@ -26,7 +26,7 @@ _IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "fal
# Force redirect deep_gemm cache_dir # Force redirect deep_gemm cache_dir
os.environ["DG_JIT_CACHE_DIR"] = os.getenv( os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
"SGL_DG_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "deep_gemm") "SGLANG_DG_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "deep_gemm")
) )
# Refer to https://github.com/deepseek-ai/DeepGEMM/commit/d75b218b7b8f4a5dd5406ac87905039ead3ae42f # Refer to https://github.com/deepseek-ai/DeepGEMM/commit/d75b218b7b8f4a5dd5406ac87905039ead3ae42f
......
...@@ -440,9 +440,10 @@ class DeepEPMoE(FusedMoE): ...@@ -440,9 +440,10 @@ class DeepEPMoE(FusedMoE):
hidden_states, hidden_states_scale, _, _, masked_m, expected_m = dispatch_output hidden_states, hidden_states_scale, _, _, masked_m, expected_m = dispatch_output
assert self.quant_method is not None assert self.quant_method is not None
assert self.moe_runner_config.activation == "silu" assert self.moe_runner_config.activation == "silu"
assert ( assert hidden_states_scale.dtype == torch.float32 or (
hidden_states_scale.dtype == torch.float32 deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
), f"hidden_states_scale.dtype: {hidden_states_scale.dtype}" and hidden_states_scale.dtype == torch.int32
), f"hidden_states_scale.dtype: {hidden_states_scale.dtype}, DEEPGEMM_SCALE_UE8M0: {deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0}"
# GroupGemm-0 # GroupGemm-0
num_groups, m, k = hidden_states.size() num_groups, m, k = hidden_states.size()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment