Unverified Commit 8f4b313c authored by wangxiyuan's avatar wangxiyuan Committed by GitHub
Browse files

[Misc] rename torch_dtype to dtype (#26695)


Signed-off-by: default avatarwangxiyuan <wangxiyuan1007@gmail.com>
parent f93e3480
......@@ -114,7 +114,7 @@ class FlashConfig(PretrainedConfig):
attention_dropout=0.0,
mla_scale_q_lora=False,
mla_scale_kv_lora=False,
torch_dtype="bfloat16",
dtype="bfloat16",
params_dtype="bfloat16",
router_dtype="float32",
router_bias=False,
......@@ -130,7 +130,7 @@ class FlashConfig(PretrainedConfig):
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
torch_dtype=torch_dtype,
dtype=dtype,
params_dtype=params_dtype,
router_dtype=router_dtype,
topk_method=topk_method,
......
......@@ -987,7 +987,7 @@ class NemotronH_Nano_VL_V2(
prefix=maybe_prefix(prefix, "language_model"),
)
self.vision_model = self.get_vit_model_from_radio_config(config).to(
self.language_model.config.torch_dtype
self.language_model.config.dtype
)
# Construct the vision projection.
......@@ -1008,7 +1008,7 @@ class NemotronH_Nano_VL_V2(
ReLUSquaredActivation(),
nn.Linear(vision_projection_hidden_size, llm_hidden_size, bias=False),
)
self.mlp1 = self.mlp1.to(self.language_model.config.torch_dtype)
self.mlp1 = self.mlp1.to(self.language_model.config.dtype)
self.config = config
self.model_config = vllm_config.model_config
......
......@@ -338,7 +338,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
group_size=None,
norm_before_gate=True,
device=current_platform.current_device(),
dtype=config.torch_dtype,
dtype=config.dtype,
)
self.out_proj = RowParallelLinear(
......@@ -847,7 +847,7 @@ class Qwen3NextDecoderLayer(nn.Module):
1,
1,
config.hidden_size,
dtype=config.torch_dtype,
dtype=config.dtype,
),
)
self.ffn_layer_scale = torch.nn.Parameter(
......@@ -855,7 +855,7 @@ class Qwen3NextDecoderLayer(nn.Module):
1,
1,
config.hidden_size,
dtype=config.torch_dtype,
dtype=config.dtype,
),
)
......
......@@ -530,7 +530,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
with init_on_device_without_buffers("meta"):
self.model: PreTrainedModel = AutoModel.from_config(
self.config,
torch_dtype=self.model_config.dtype,
dtype=self.model_config.dtype,
trust_remote_code=self.model_config.trust_remote_code,
)
......
......@@ -157,7 +157,7 @@ class TransformersForSequenceClassification(TransformersPoolingBase):
with torch.device("meta"):
seq_cls_model = AutoModelForSequenceClassification.from_config(
self.config,
torch_dtype=self.model_config.dtype,
dtype=self.model_config.dtype,
trust_remote_code=self.model_config.trust_remote_code,
)
......
......@@ -500,8 +500,8 @@ class CudaPlatformBase(Platform):
return supported
@classmethod
def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
if torch_dtype == torch.bfloat16: # noqa: SIM102
def check_if_supports_dtype(cls, dtype: torch.dtype):
if dtype == torch.bfloat16: # noqa: SIM102
if not cls.has_device_capability(80):
capability = cls.get_device_capability()
gpu_name = cls.get_device_name()
......
......@@ -563,7 +563,7 @@ class Platform:
return False
@classmethod
def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
def check_if_supports_dtype(cls, dtype: torch.dtype):
"""
Check if the dtype is supported by the current platform.
"""
......
......@@ -484,8 +484,8 @@ class RocmPlatform(Platform):
return True
@classmethod
def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
if torch_dtype == torch.bfloat16: # noqa: SIM102
def check_if_supports_dtype(cls, dtype: torch.dtype):
if dtype == torch.bfloat16: # noqa: SIM102
if not cls.has_device_capability(80):
capability = cls.get_device_capability()
gpu_name = cls.get_device_name()
......
......@@ -236,8 +236,8 @@ class XPUPlatform(Platform):
return torch.xpu.device_count()
@classmethod
def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
if torch_dtype == torch.bfloat16: # noqa: SIM102
def check_if_supports_dtype(cls, dtype: torch.dtype):
if dtype == torch.bfloat16: # noqa: SIM102
device_name = cls.get_device_name().lower()
# client gpu a770
if device_name.count("a770") > 0:
......
......@@ -806,7 +806,7 @@ def create_kv_caches_with_random_flash(
current_platform.seed_everything(seed)
torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
generic_kv_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
assert cache_layout in ("NHD", "HND")
stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, 2, 4)
......@@ -819,7 +819,7 @@ def create_kv_caches_with_random_flash(
for _ in range(num_layers):
key_value_cache = torch.empty(
size=kv_cache_allocation_shape, dtype=torch_dtype, device=device
size=kv_cache_allocation_shape, dtype=dtype, device=device
).permute(*stride_order)
if cache_dtype in ["auto", "half", "bfloat16", "float"]:
key_value_cache.uniform_(-scale, scale)
......@@ -851,14 +851,14 @@ def create_kv_caches_with_random(
current_platform.seed_everything(seed)
torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
scale = head_size**-0.5
x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
x = 16 // torch.tensor([], dtype=dtype).element_size()
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
key_caches: list[torch.Tensor] = []
for _ in range(num_layers):
key_cache = torch.empty(size=key_cache_shape, dtype=torch_dtype, device=device)
key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device=device)
if cache_dtype in ["auto", "half", "bfloat16", "float"]:
key_cache.uniform_(-scale, scale)
elif cache_dtype == "fp8":
......@@ -870,9 +870,7 @@ def create_kv_caches_with_random(
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
value_caches: list[torch.Tensor] = []
for _ in range(num_layers):
value_cache = torch.empty(
size=value_cache_shape, dtype=torch_dtype, device=device
)
value_cache = torch.empty(size=value_cache_shape, dtype=dtype, device=device)
if cache_dtype in ["auto", "half", "bfloat16", "float"]:
value_cache.uniform_(-scale, scale)
elif cache_dtype == "fp8":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment