"...resnet50_tensorflow.git" did not exist on "27fb855b027ead16d2616dcb59c67409a2176b7f"
Unverified Commit 70b68029 authored by Liangsheng Yin, committed by GitHub

Optimize conflicts between CUDA graph and vocab mask tensors (#1392)

parent f3d32f88
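
This commit removes the per-model Sampler/SampleOutput plumbing (each model's forward now returns only the LogitsProcessorOutput) and makes the vocab mask CUDA-graph friendly: SamplingBatchInfo preallocates a fixed (max_bs, vocab_size) boolean vocab_mask and copies each batch's mask into it in place. Below is a minimal, hedged sketch of that buffer-reuse pattern; the sizes and helper name are illustrative assumptions, not taken from the SGLang code, but the idea matches the diff: CUDA graph replay requires stable tensor addresses, so the mask must be written into a persistent buffer rather than reallocated each step.

# Minimal sketch of the pattern adopted in this commit (names and sizes are
# illustrative assumptions, not SGLang's actual code).
from typing import Optional

import torch

MAX_BS, VOCAB_SIZE = 8, 32000  # hypothetical capture sizes

# Fixed buffer allocated once, e.g. when building the dummy batch for capture.
vocab_mask = torch.zeros((MAX_BS, VOCAB_SIZE), dtype=torch.bool)


def write_vocab_mask(buf: torch.Tensor, mask: Optional[torch.Tensor], bs: int) -> None:
    """Copy the current batch's mask into the persistent buffer in place."""
    if mask is None:
        buf[:bs].fill_(False)  # no constrained decoding: mask nothing out
    else:
        buf[:bs] = mask  # in-place copy keeps the buffer's address stable


# Example: a 2-request batch where request 0 forbids token id 7.
step_mask = torch.zeros((2, VOCAB_SIZE), dtype=torch.bool)
step_mask[0, 7] = True
write_vocab_mask(vocab_mask, step_mask, bs=2)
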
@@ -23,7 +23,6 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
@@ -75,25 +74,7 @@ class LlamaForClassification(nn.Module):
             output_top_logprobs=None,
         )
-        # A dummy to make this work
-        sample_output = SampleOutput(
-            success=torch.full(
-                size=(scores.shape[0],),
-                fill_value=True,
-                dtype=torch.bool,
-            ),
-            probs=torch.full(
-                size=(scores.shape[0], 1),
-                fill_value=1.0,
-                dtype=torch.float16,
-            ),
-            batch_next_token_ids=torch.full(
-                size=(scores.shape[0],),
-                fill_value=0,
-                dtype=torch.long,
-            ),
-        )
-        return sample_output, logits_output
+        return logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = self.param_dict
...
@@ -39,7 +39,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -298,7 +297,6 @@ class MiniCPMForCausalLM(nn.Module):
         self.scale_width = self.config.hidden_size / self.config.dim_model_base
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -316,11 +314,9 @@ class MiniCPMForCausalLM(nn.Module):
             lm_head_weight = self.model.embed_tokens.weight
         else:
             lm_head_weight = self.lm_head.weight
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, lm_head_weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
...
@@ -42,7 +42,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -572,7 +571,6 @@ class MiniCPM3ForCausalLM(nn.Module):
         self.scale_width = self.config.hidden_size / self.config.dim_model_base
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -590,11 +588,9 @@ class MiniCPM3ForCausalLM(nn.Module):
             lm_head_weight = self.model.embed_tokens.weight
         else:
             lm_head_weight = self.lm_head.weight
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, lm_head_weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
...
@@ -41,7 +41,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -300,7 +299,6 @@ class MixtralForCausalLM(nn.Module):
         self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     def forward(
         self,
@@ -310,11 +308,9 @@ class MixtralForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
...
@@ -45,7 +45,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -334,7 +333,6 @@ class QuantMixtralForCausalLM(nn.Module):
         self.model = MixtralModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -345,11 +343,9 @@ class QuantMixtralForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
...
@@ -39,7 +39,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -252,7 +251,6 @@ class QWenLMHeadModel(nn.Module):
         vocab_size = ((config.vocab_size + 63) // 64) * 64
         self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -262,11 +260,9 @@ class QWenLMHeadModel(nn.Module):
         input_metadata: InputMetadata,
     ):
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
...
@@ -40,7 +40,6 @@ from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

 Qwen2Config = None
@@ -277,7 +276,6 @@ class Qwen2ForCausalLM(nn.Module):
         self.model = Qwen2Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

     @torch.no_grad()
@@ -291,11 +289,9 @@ class Qwen2ForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
         if not get_embedding:
-            logits_output = self.logits_processor(
+            return self.logits_processor(
                 input_ids, hidden_states, self.lm_head.weight, input_metadata
             )
-            sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-            return sample_output, logits_output
         else:
             return self.pooler(hidden_states, input_metadata)
...
@@ -47,7 +47,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -365,7 +364,6 @@ class Qwen2MoeForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -376,11 +374,9 @@ class Qwen2MoeForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
...
@@ -40,7 +40,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -250,7 +249,6 @@ class StableLmForCausalLM(nn.Module):
         self.model = StableLMEpochModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -261,11 +259,9 @@ class StableLmForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
...
@@ -41,7 +41,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.model_runner import InputMetadata
@@ -307,7 +306,6 @@ class XverseForCausalLM(nn.Module):
         self.model = XverseModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
         self.param_dict = dict(self.named_parameters())
@@ -320,12 +318,9 @@ class XverseForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        # print(f"{hidden_states=}")
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(
         self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
...
@@ -44,7 +44,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -383,7 +382,6 @@ class XverseMoeForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
         self.param_dict = dict(self.named_parameters())
@@ -395,11 +393,9 @@ class XverseMoeForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
        )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
...
@@ -41,7 +41,6 @@ class SamplingBatchInfo:
         # Vocab bias and min_ps are not supported in CUDA graph
         return (
             self.logit_bias is None
-            and self.vocab_mask is None
             and self.linear_penalties is None
             and self.scaling_penalties is None
             and not self.need_min_p_sampling
@@ -50,9 +49,11 @@ class SamplingBatchInfo:
     @classmethod
     def dummy_one(cls, max_bs: int, vocab_size: int):
         ret = cls(vocab_size=vocab_size)
-        ret.temperatures = torch.ones((max_bs, 1), dtype=torch.float, device="cuda")
-        ret.top_ps = torch.ones((max_bs,), dtype=torch.float, device="cuda")
-        ret.top_ks = torch.ones((max_bs,), dtype=torch.int, device="cuda")
+        with torch.device("cuda"):
+            ret.temperatures = torch.ones((max_bs, 1), dtype=torch.float)
+            ret.top_ps = torch.ones((max_bs,), dtype=torch.float)
+            ret.top_ks = torch.ones((max_bs,), dtype=torch.int)
+            ret.vocab_mask = torch.zeros((max_bs, vocab_size), dtype=torch.bool)
         return ret

     def __getitem__(self, key):
@@ -64,6 +65,7 @@ class SamplingBatchInfo:
                 temperatures=self.temperatures[key],
                 top_ps=self.top_ps[key],
                 top_ks=self.top_ks[key],
+                vocab_mask=self.vocab_mask[key],
             )
         else:
             raise NotImplementedError
@@ -77,6 +79,11 @@ class SamplingBatchInfo:
         self.top_ps[:bs] = other.top_ps
         self.top_ks[:bs] = other.top_ks
+        if other.vocab_mask is None:
+            self.vocab_mask[:bs].fill_(False)
+        else:
+            self.vocab_mask[:bs] = other.vocab_mask

     @classmethod
     def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
         device = "cuda"
...
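
For context, a hedged illustration of why the mask has to live in a preallocated buffer: torch.cuda.CUDAGraph replays the captured kernels against the captured memory addresses, so per-step inputs must be copied into the same tensors rather than freshly allocated. The sketch below is self-contained, uses only public PyTorch APIs, and is guarded so it only runs when a GPU is present; it is not SGLang's runner code.

# Hedged illustration (not SGLang code): apply a vocab mask inside a captured
# CUDA graph, then update the mask in place and replay.
import torch

if torch.cuda.is_available():
    bs, vocab = 4, 128
    static_logits = torch.randn(bs, vocab, device="cuda")
    static_mask = torch.zeros(bs, vocab, dtype=torch.bool, device="cuda")
    static_out = torch.empty_like(static_logits)

    # Warm up on a side stream before capture, as recommended by PyTorch docs.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        static_out.copy_(static_logits.masked_fill(static_mask, float("-inf")))
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_out.copy_(static_logits.masked_fill(static_mask, float("-inf")))

    # Per step: write new data into the captured buffers, then replay.
    static_mask.fill_(False)
    static_mask[0, 7] = True  # ban token id 7 for request 0
    g.replay()
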