Unverified commit 8f4b313c, authored by wangxiyuan and committed by GitHub
Browse files

[Misc] rename torch_dtype to dtype (#26695)


Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
parent f93e3480
...@@ -114,7 +114,7 @@ class FlashConfig(PretrainedConfig): ...@@ -114,7 +114,7 @@ class FlashConfig(PretrainedConfig):
attention_dropout=0.0, attention_dropout=0.0,
mla_scale_q_lora=False, mla_scale_q_lora=False,
mla_scale_kv_lora=False, mla_scale_kv_lora=False,
torch_dtype="bfloat16", dtype="bfloat16",
params_dtype="bfloat16", params_dtype="bfloat16",
router_dtype="float32", router_dtype="float32",
router_bias=False, router_bias=False,
...@@ -130,7 +130,7 @@ class FlashConfig(PretrainedConfig): ...@@ -130,7 +130,7 @@ class FlashConfig(PretrainedConfig):
bos_token_id=bos_token_id, bos_token_id=bos_token_id,
eos_token_id=eos_token_id, eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings, tie_word_embeddings=tie_word_embeddings,
torch_dtype=torch_dtype, dtype=dtype,
params_dtype=params_dtype, params_dtype=params_dtype,
router_dtype=router_dtype, router_dtype=router_dtype,
topk_method=topk_method, topk_method=topk_method,
......
...@@ -987,7 +987,7 @@ class NemotronH_Nano_VL_V2( ...@@ -987,7 +987,7 @@ class NemotronH_Nano_VL_V2(
prefix=maybe_prefix(prefix, "language_model"), prefix=maybe_prefix(prefix, "language_model"),
) )
self.vision_model = self.get_vit_model_from_radio_config(config).to( self.vision_model = self.get_vit_model_from_radio_config(config).to(
self.language_model.config.torch_dtype self.language_model.config.dtype
) )
# Construct the vision projection. # Construct the vision projection.
...@@ -1008,7 +1008,7 @@ class NemotronH_Nano_VL_V2( ...@@ -1008,7 +1008,7 @@ class NemotronH_Nano_VL_V2(
ReLUSquaredActivation(), ReLUSquaredActivation(),
nn.Linear(vision_projection_hidden_size, llm_hidden_size, bias=False), nn.Linear(vision_projection_hidden_size, llm_hidden_size, bias=False),
) )
self.mlp1 = self.mlp1.to(self.language_model.config.torch_dtype) self.mlp1 = self.mlp1.to(self.language_model.config.dtype)
self.config = config self.config = config
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
......
...@@ -338,7 +338,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): ...@@ -338,7 +338,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
group_size=None, group_size=None,
norm_before_gate=True, norm_before_gate=True,
device=current_platform.current_device(), device=current_platform.current_device(),
dtype=config.torch_dtype, dtype=config.dtype,
) )
self.out_proj = RowParallelLinear( self.out_proj = RowParallelLinear(
...@@ -847,7 +847,7 @@ class Qwen3NextDecoderLayer(nn.Module): ...@@ -847,7 +847,7 @@ class Qwen3NextDecoderLayer(nn.Module):
1, 1,
1, 1,
config.hidden_size, config.hidden_size,
dtype=config.torch_dtype, dtype=config.dtype,
), ),
) )
self.ffn_layer_scale = torch.nn.Parameter( self.ffn_layer_scale = torch.nn.Parameter(
...@@ -855,7 +855,7 @@ class Qwen3NextDecoderLayer(nn.Module): ...@@ -855,7 +855,7 @@ class Qwen3NextDecoderLayer(nn.Module):
1, 1,
1, 1,
config.hidden_size, config.hidden_size,
dtype=config.torch_dtype, dtype=config.dtype,
), ),
) )
......
...@@ -530,7 +530,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): ...@@ -530,7 +530,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
with init_on_device_without_buffers("meta"): with init_on_device_without_buffers("meta"):
self.model: PreTrainedModel = AutoModel.from_config( self.model: PreTrainedModel = AutoModel.from_config(
self.config, self.config,
torch_dtype=self.model_config.dtype, dtype=self.model_config.dtype,
trust_remote_code=self.model_config.trust_remote_code, trust_remote_code=self.model_config.trust_remote_code,
) )
......
...@@ -157,7 +157,7 @@ class TransformersForSequenceClassification(TransformersPoolingBase): ...@@ -157,7 +157,7 @@ class TransformersForSequenceClassification(TransformersPoolingBase):
with torch.device("meta"): with torch.device("meta"):
seq_cls_model = AutoModelForSequenceClassification.from_config( seq_cls_model = AutoModelForSequenceClassification.from_config(
self.config, self.config,
torch_dtype=self.model_config.dtype, dtype=self.model_config.dtype,
trust_remote_code=self.model_config.trust_remote_code, trust_remote_code=self.model_config.trust_remote_code,
) )
......
...@@ -500,8 +500,8 @@ class CudaPlatformBase(Platform): ...@@ -500,8 +500,8 @@ class CudaPlatformBase(Platform):
return supported return supported
@classmethod @classmethod
def check_if_supports_dtype(cls, torch_dtype: torch.dtype): def check_if_supports_dtype(cls, dtype: torch.dtype):
if torch_dtype == torch.bfloat16: # noqa: SIM102 if dtype == torch.bfloat16: # noqa: SIM102
if not cls.has_device_capability(80): if not cls.has_device_capability(80):
capability = cls.get_device_capability() capability = cls.get_device_capability()
gpu_name = cls.get_device_name() gpu_name = cls.get_device_name()
......
...@@ -563,7 +563,7 @@ class Platform: ...@@ -563,7 +563,7 @@ class Platform:
return False return False
@classmethod @classmethod
def check_if_supports_dtype(cls, torch_dtype: torch.dtype): def check_if_supports_dtype(cls, dtype: torch.dtype):
""" """
Check if the dtype is supported by the current platform. Check if the dtype is supported by the current platform.
""" """
......
...@@ -484,8 +484,8 @@ class RocmPlatform(Platform): ...@@ -484,8 +484,8 @@ class RocmPlatform(Platform):
return True return True
@classmethod @classmethod
def check_if_supports_dtype(cls, torch_dtype: torch.dtype): def check_if_supports_dtype(cls, dtype: torch.dtype):
if torch_dtype == torch.bfloat16: # noqa: SIM102 if dtype == torch.bfloat16: # noqa: SIM102
if not cls.has_device_capability(80): if not cls.has_device_capability(80):
capability = cls.get_device_capability() capability = cls.get_device_capability()
gpu_name = cls.get_device_name() gpu_name = cls.get_device_name()
......
...@@ -236,8 +236,8 @@ class XPUPlatform(Platform): ...@@ -236,8 +236,8 @@ class XPUPlatform(Platform):
return torch.xpu.device_count() return torch.xpu.device_count()
@classmethod @classmethod
def check_if_supports_dtype(cls, torch_dtype: torch.dtype): def check_if_supports_dtype(cls, dtype: torch.dtype):
if torch_dtype == torch.bfloat16: # noqa: SIM102 if dtype == torch.bfloat16: # noqa: SIM102
device_name = cls.get_device_name().lower() device_name = cls.get_device_name().lower()
# client gpu a770 # client gpu a770
if device_name.count("a770") > 0: if device_name.count("a770") > 0:
......
...@@ -806,7 +806,7 @@ def create_kv_caches_with_random_flash( ...@@ -806,7 +806,7 @@ def create_kv_caches_with_random_flash(
current_platform.seed_everything(seed) current_platform.seed_everything(seed)
torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
generic_kv_cache_shape = (num_blocks, 2, block_size, num_heads, head_size) generic_kv_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
assert cache_layout in ("NHD", "HND") assert cache_layout in ("NHD", "HND")
stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, 2, 4) stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, 2, 4)
...@@ -819,7 +819,7 @@ def create_kv_caches_with_random_flash( ...@@ -819,7 +819,7 @@ def create_kv_caches_with_random_flash(
for _ in range(num_layers): for _ in range(num_layers):
key_value_cache = torch.empty( key_value_cache = torch.empty(
size=kv_cache_allocation_shape, dtype=torch_dtype, device=device size=kv_cache_allocation_shape, dtype=dtype, device=device
).permute(*stride_order) ).permute(*stride_order)
if cache_dtype in ["auto", "half", "bfloat16", "float"]: if cache_dtype in ["auto", "half", "bfloat16", "float"]:
key_value_cache.uniform_(-scale, scale) key_value_cache.uniform_(-scale, scale)
...@@ -851,14 +851,14 @@ def create_kv_caches_with_random( ...@@ -851,14 +851,14 @@ def create_kv_caches_with_random(
current_platform.seed_everything(seed) current_platform.seed_everything(seed)
torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
scale = head_size**-0.5 scale = head_size**-0.5
x = 16 // torch.tensor([], dtype=torch_dtype).element_size() x = 16 // torch.tensor([], dtype=dtype).element_size()
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
key_caches: list[torch.Tensor] = [] key_caches: list[torch.Tensor] = []
for _ in range(num_layers): for _ in range(num_layers):
key_cache = torch.empty(size=key_cache_shape, dtype=torch_dtype, device=device) key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device=device)
if cache_dtype in ["auto", "half", "bfloat16", "float"]: if cache_dtype in ["auto", "half", "bfloat16", "float"]:
key_cache.uniform_(-scale, scale) key_cache.uniform_(-scale, scale)
elif cache_dtype == "fp8": elif cache_dtype == "fp8":
...@@ -870,9 +870,7 @@ def create_kv_caches_with_random( ...@@ -870,9 +870,7 @@ def create_kv_caches_with_random(
value_cache_shape = (num_blocks, num_heads, head_size, block_size) value_cache_shape = (num_blocks, num_heads, head_size, block_size)
value_caches: list[torch.Tensor] = [] value_caches: list[torch.Tensor] = []
for _ in range(num_layers): for _ in range(num_layers):
value_cache = torch.empty( value_cache = torch.empty(size=value_cache_shape, dtype=dtype, device=device)
size=value_cache_shape, dtype=torch_dtype, device=device
)
if cache_dtype in ["auto", "half", "bfloat16", "float"]: if cache_dtype in ["auto", "half", "bfloat16", "float"]:
value_cache.uniform_(-scale, scale) value_cache.uniform_(-scale, scale)
elif cache_dtype == "fp8": elif cache_dtype == "fp8":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment