Unverified Commit f024795e authored by YyWangCS's avatar YyWangCS Committed by GitHub
Browse files

Replace torch.jit.script with torch.compile in get_masked_input_and_mask to...

Replace torch.jit.script with torch.compile in get_masked_input_and_mask to fix benchmark underreporting (#8733)
parent b102353f
...@@ -26,7 +26,12 @@ from sglang.srt.layers.quantization.base_config import ( ...@@ -26,7 +26,12 @@ from sglang.srt.layers.quantization.base_config import (
method_has_implemented_embedding, method_has_implemented_embedding,
) )
from sglang.srt.layers.quantization.unquant import UnquantizedEmbeddingMethod from sglang.srt.layers.quantization.unquant import UnquantizedEmbeddingMethod
from sglang.srt.utils import cpu_has_amx_support, is_cpu, set_weight_attrs from sglang.srt.utils import (
cpu_has_amx_support,
get_compiler_backend,
is_cpu,
set_weight_attrs,
)
DEFAULT_VOCAB_PADDING_SIZE = 64 DEFAULT_VOCAB_PADDING_SIZE = 64
...@@ -117,7 +122,7 @@ class VocabParallelEmbeddingShardIndices: ...@@ -117,7 +122,7 @@ class VocabParallelEmbeddingShardIndices:
assert self.num_added_elements <= self.num_added_elements_padded assert self.num_added_elements <= self.num_added_elements_padded
@torch.jit.script @torch.compile(dynamic=True, backend=get_compiler_backend())
def get_masked_input_and_mask( def get_masked_input_and_mask(
input_: torch.Tensor, input_: torch.Tensor,
org_vocab_start_index: int, org_vocab_start_index: int,
...@@ -126,7 +131,7 @@ def get_masked_input_and_mask( ...@@ -126,7 +131,7 @@ def get_masked_input_and_mask(
added_vocab_start_index: int, added_vocab_start_index: int,
added_vocab_end_index: int, added_vocab_end_index: int,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# torch.jit.script will fuse all of the pointwise ops below # torch.compile will fuse all of the pointwise ops below
# into a single kernel, making it very fast # into a single kernel, making it very fast
org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index) org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index)
added_vocab_mask = (input_ >= added_vocab_start_index) & ( added_vocab_mask = (input_ >= added_vocab_start_index) & (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment