"vscode:/vscode.git/clone" did not exist on "1d3b429f40888d935e15608b2c7707f5b028564e"
Unverified Commit 76524b70 authored by Ke Bao, committed by GitHub

Fix torch compile for deepseek-v2 (#1442)

parent 3a6e0418
@@ -41,6 +41,9 @@ if TYPE_CHECKING:
 def _to_torch(model: torch.nn.Module, reverse: bool = False):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
+            # NOTE: FusedMoE's torch-native implementation is not efficient
+            if "FusedMoE" in sub.__class__.__name__:
+                continue
             if reverse:
                 sub._forward_method = sub.forward_cuda
                 setattr(sub, "is_torch_compile", False)
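For context, a minimal sketch of the dispatch pattern this hunk relies on: CustomOp modules route calls through `_forward_method`, and `_to_torch` repoints that method at the pure-PyTorch path so torch.compile can trace it. Everything below (the toy `CustomOp`, the `FusedMoEStub`) is hypothetical scaffolding, not sglang's actual classes.

```python
import torch

class CustomOp(torch.nn.Module):
    """Toy stand-in: dispatch every call through _forward_method."""

    def __init__(self):
        super().__init__()
        self._forward_method = self.forward_cuda

    def forward(self, x):
        return self._forward_method(x)

    def forward_native(self, x):
        # Pure-PyTorch path that torch.compile can trace.
        return x * 2

    def forward_cuda(self, x):
        # Stands in for a custom-kernel path; stubbed with the same math here.
        return x * 2

class FusedMoEStub(CustomOp):
    pass

def _to_torch(model: torch.nn.Module, reverse: bool = False):
    for sub in model._modules.values():
        if isinstance(sub, CustomOp):
            # FusedMoE keeps its CUDA forward: its native path is too slow.
            if "FusedMoE" in sub.__class__.__name__:
                continue
            sub._forward_method = sub.forward_cuda if reverse else sub.forward_native

model = torch.nn.Sequential(CustomOp(), FusedMoEStub())
_to_torch(model)
assert model[0]._forward_method == model[0].forward_native
assert model[1]._forward_method == model[1].forward_cuda  # FusedMoE was skipped
```

The class-name check is what lets FusedMoE keep its efficient CUDA forward while the rest of the model is switched to the native, compilable mode.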
@@ -105,7 +108,15 @@ class CudaGraphRunner:
             self.capture_bs = list(range(1, 32)) + [64, 128]
         else:
             self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
-        self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if self.use_torch_compile else []
+        self.compile_bs = (
+            [
+                bs
+                for bs in self.capture_bs
+                if bs <= self.model_runner.server_args.max_torch_compile_bs
+            ]
+            if self.use_torch_compile
+            else []
+        )
 
         # Common inputs
         self.max_bs = max(self.capture_bs)
......
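The behavioral point of the hunk above: `compile_bs` was previously a hard-coded list that could contain batch sizes the CUDA-graph capture list never includes, so torch.compile could target a shape that was never captured. A standalone sketch of the new derivation (plain function over assumed inputs, not the sglang module itself):

```python
def compile_batch_sizes(capture_bs, max_torch_compile_bs, use_torch_compile=True):
    """Derive torch.compile batch sizes from the captured ones."""
    if not use_torch_compile:
        return []
    return [bs for bs in capture_bs if bs <= max_torch_compile_bs]

# The old hard-coded list could include sizes that were never captured:
capture_bs = list(range(1, 32)) + [64, 128]  # one of the capture lists above
old_compile_bs = [1, 2, 4, 8, 16, 24, 32]
print([bs for bs in old_compile_bs if bs not in capture_bs])  # [32]

# The new list is always a subset of capture_bs, bounded by the flag:
print(compile_batch_sizes(capture_bs, max_torch_compile_bs=32))  # [1, 2, ..., 31]
```

Deriving `compile_bs` from `capture_bs` guarantees every compiled batch size is also a captured one, with `--max-torch-compile-bs` as the upper bound.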
@@ -653,6 +653,7 @@ class DeepseekV2ForCausalLM(nn.Module):
         )
         self.logits_processor = LogitsProcessor(config)
 
+    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
......
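Decorating `forward` with `@torch.no_grad()` keeps autograd bookkeeping out of the inference path, which also keeps it out of anything torch.compile traces through this forward. A minimal illustration (`TinyLM` is a made-up toy, not DeepseekV2):

```python
import torch

class TinyLM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    @torch.no_grad()  # inference-only forward: no autograd state is recorded
    def forward(self, x):
        return self.proj(x)

out = TinyLM()(torch.randn(2, 8))
assert not out.requires_grad  # gradient tracking never turns on inside forward
```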
@@ -110,6 +110,7 @@ class ServerArgs:
     disable_custom_all_reduce: bool = False
     enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
+    max_torch_compile_bs: int = 32
     torchao_config: str = ""
     enable_p2p_check: bool = False
     enable_mla: bool = False
......@@ -523,6 +524,12 @@ class ServerArgs:
action="store_true",
help="Optimize the model with torch.compile. Experimental feature.",
)
parser.add_argument(
"--max-torch-compile-bs",
type=int,
default=ServerArgs.max_torch_compile_bs,
help="Set the maximum batch size when using torch compile.",
)
parser.add_argument(
"--torchao-config",
type=str,
......
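How the new flag wires together, reduced to a self-contained sketch (the real ServerArgs dataclass has many more fields; this keeps only the two relevant ones):

```python
import argparse
from dataclasses import dataclass

@dataclass
class ServerArgs:
    enable_torch_compile: bool = False
    max_torch_compile_bs: int = 32  # new default added in this commit

parser = argparse.ArgumentParser()
parser.add_argument("--enable-torch-compile", action="store_true")
parser.add_argument("--max-torch-compile-bs", type=int,
                    default=ServerArgs.max_torch_compile_bs)

args = parser.parse_args(["--enable-torch-compile", "--max-torch-compile-bs", "16"])
print(args.enable_torch_compile, args.max_torch_compile_bs)  # True 16
```

At the command line the new option rides alongside the existing one, e.g. `python -m sglang.launch_server --enable-torch-compile --max-torch-compile-bs 16 ...` (assuming sglang's usual launch entrypoint).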