Unverified Commit 76524b70 authored by Ke Bao, committed by GitHub

Fix torch compile for deepseek-v2 (#1442)

parent 3a6e0418
@@ -41,6 +41,9 @@ if TYPE_CHECKING:
 def _to_torch(model: torch.nn.Module, reverse: bool = False):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
+            # NOTE: The FusedMoE torch-native implementation is not efficient
+            if "FusedMoE" in sub.__class__.__name__:
+                continue
             if reverse:
                 sub._forward_method = sub.forward_cuda
                 setattr(sub, "is_torch_compile", False)
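For context, here is a minimal self-contained sketch of the traversal pattern this hunk modifies. The `CustomOp` class below is a stand-in, not sglang's actual implementation, and the parts of `_to_torch` outside the hunk are reconstructed assumptions: the function recursively walks `model._modules`, swaps each CustomOp's forward method between its CUDA kernel and its torch-native path, and the patch now skips any FusedMoE module because its torch-native path is slow.

```python
import torch

class CustomOp(torch.nn.Module):
    """Stand-in for sglang's CustomOp: has CUDA and torch-native forwards."""
    def forward_cuda(self, x): return x
    def forward_native(self, x): return x

def to_torch(model: torch.nn.Module, reverse: bool = False) -> None:
    # Recursively swap every CustomOp's forward between the CUDA kernel
    # and the torch-native path that torch.compile can trace.
    for sub in model._modules.values():
        if isinstance(sub, CustomOp):
            # The patch skips FusedMoE: its torch-native implementation is
            # inefficient, so it stays on the CUDA path even under compile.
            if "FusedMoE" in sub.__class__.__name__:
                continue
            sub._forward_method = sub.forward_cuda if reverse else sub.forward_native
            sub.is_torch_compile = not reverse
        if isinstance(sub, torch.nn.Module):
            to_torch(sub, reverse)
```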
@@ -105,7 +108,15 @@ class CudaGraphRunner:
             self.capture_bs = list(range(1, 32)) + [64, 128]
         else:
             self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
-        self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if self.use_torch_compile else []
+        self.compile_bs = (
+            [
+                bs
+                for bs in self.capture_bs
+                if bs <= self.model_runner.server_args.max_torch_compile_bs
+            ]
+            if self.use_torch_compile
+            else []
+        )
 
         # Common inputs
         self.max_bs = max(self.capture_bs)
...
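As a sanity check on the new logic: with the default `max_torch_compile_bs = 32`, filtering `capture_bs` reproduces exactly the list that was previously hardcoded, so default behavior is unchanged while larger or smaller caps become configurable.

```python
# With the default max_torch_compile_bs of 32, the new filter reproduces
# the previously hardcoded compile list [1, 2, 4, 8, 16, 24, 32].
capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]  # 1, 2, 4, 8, 16, ..., 160
max_torch_compile_bs = 32
compile_bs = [bs for bs in capture_bs if bs <= max_torch_compile_bs]
print(compile_bs)  # [1, 2, 4, 8, 16, 24, 32]
```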
@@ -653,6 +653,7 @@ class DeepseekV2ForCausalLM(nn.Module):
         )
         self.logits_processor = LogitsProcessor(config)
 
+    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
...
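The added `@torch.no_grad()` decorator runs the whole forward pass without autograd tracking, so inference never builds a gradient graph. A minimal illustration with a toy module (not the DeepSeek model):

```python
import torch

class TinyLM(torch.nn.Module):  # toy stand-in, not DeepseekV2ForCausalLM
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    @torch.no_grad()  # same decorator the patch adds to forward()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)

out = TinyLM()(torch.randn(2, 8))
print(out.requires_grad)  # False: no autograd graph is built during inference
```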
@@ -110,6 +110,7 @@ class ServerArgs:
     disable_custom_all_reduce: bool = False
     enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
+    max_torch_compile_bs: int = 32
     torchao_config: str = ""
     enable_p2p_check: bool = False
     enable_mla: bool = False
@@ -523,6 +524,12 @@ class ServerArgs:
             action="store_true",
             help="Optimize the model with torch.compile. Experimental feature.",
         )
+        parser.add_argument(
+            "--max-torch-compile-bs",
+            type=int,
+            default=ServerArgs.max_torch_compile_bs,
+            help="Set the maximum batch size when using torch compile.",
+        )
         parser.add_argument(
             "--torchao-config",
             type=str,
...
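With the new flag, torch.compile can be capped at a smaller batch size than CUDA graph capture. A hypothetical launch command (the launcher module and model path are assumptions; the two flags come from this patch):

```bash
python -m sglang.launch_server \
    --model-path deepseek-ai/DeepSeek-V2 \
    --enable-torch-compile \
    --max-torch-compile-bs 16
```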