Commit c0bdac11 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch '0.9.2+das.opt1.alpha.dtk25041_rms_quant_squash_deepseekw4a8_push_3' into 'v0.9.2-dev'

deepseek-r1-w4a8使用rmsquant融合算子及横向融合

See merge request dcutoolkit/deeplearing/vllm!205
parents fb3c32c6 fc443d52
...@@ -166,6 +166,7 @@ if TYPE_CHECKING: ...@@ -166,6 +166,7 @@ if TYPE_CHECKING:
VLLM_USE_GLOBAL_CACHE13: bool = False VLLM_USE_GLOBAL_CACHE13: bool = False
VLLM_USE_LIGHT_OP: bool = False VLLM_USE_LIGHT_OP: bool = False
VLLM_USE_TRITON_CAT: bool = False VLLM_USE_TRITON_CAT: bool = False
USE_FUSED_RMS_QUANT: bool = False
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
def get_default_cache_root(): def get_default_cache_root():
...@@ -1104,6 +1105,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1104,6 +1105,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_MERGE_ATTN_STATES_OPT": "VLLM_USE_MERGE_ATTN_STATES_OPT":
lambda: (os.environ.get("VLLM_USE_MERGE_ATTN_STATES_OPT", "True").lower() in lambda: (os.environ.get("VLLM_USE_MERGE_ATTN_STATES_OPT", "True").lower() in
("true", "1")), ("true", "1")),
# Whether vLLM uses the fused rmsquant op (lm_faster_rmsquant); disabled unless set to "1" or "true".
"USE_FUSED_RMS_QUANT":
lambda: (os.getenv('USE_FUSED_RMS_QUANT', '0').lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -33,6 +33,12 @@ from vllm.platforms import current_platform ...@@ -33,6 +33,12 @@ from vllm.platforms import current_platform
import os import os
from vllm.model_executor.utils import gemm_bank_conf from vllm.model_executor.utils import gemm_bank_conf
if envs.USE_FUSED_RMS_QUANT:
try:
from lmslim.quantize.quant_ops import lm_faster_rmsquant
except Exception as e:
print(f"Error: Import fused rmsquant error: {e}")
logger = init_logger(__name__) logger = init_logger(__name__)
WEIGHT_LOADER_V2_SUPPORTED = [ WEIGHT_LOADER_V2_SUPPORTED = [
...@@ -327,6 +333,7 @@ class ReplicatedLinear(LinearBase): ...@@ -327,6 +333,7 @@ class ReplicatedLinear(LinearBase):
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
eps: Optional[float] = 1e-6,
prefix: str = "", prefix: str = "",
*, *,
return_bias: bool = True, return_bias: bool = True,
...@@ -338,6 +345,7 @@ class ReplicatedLinear(LinearBase): ...@@ -338,6 +345,7 @@ class ReplicatedLinear(LinearBase):
quant_config, quant_config,
prefix=prefix, prefix=prefix,
return_bias=return_bias) return_bias=return_bias)
self.eps = eps
# All the linear layer supports quant method. # All the linear layer supports quant method.
assert self.quant_method is not None assert self.quant_method is not None
...@@ -385,15 +393,53 @@ class ReplicatedLinear(LinearBase): ...@@ -385,15 +393,53 @@ class ReplicatedLinear(LinearBase):
param.data.copy_(loaded_weight) param.data.copy_(loaded_weight)
def forward( def forward(
self, x: torch.Tensor self,
input_: torch.Tensor,
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
quant_args: Optional[list] = None,
update_hd: Optional[bool] = True
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
bias = self.bias if not self.skip_bias_add else None if envs.USE_FUSED_RMS_QUANT and (rms_weight is not None or quant_args is not None):
assert self.quant_method is not None if quant_args is not None:
output = self.quant_method.apply(self, x, bias) input_quant_args = quant_args
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias: bias = self.bias if not self.skip_bias_add else None
return output assert self.quant_method is not None
return output, output_bias output = self.quant_method.apply(self, input_, bias, input_quant_args)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
else:
i_q, _scales = lm_faster_rmsquant(input=input_,
rms_weight=rms_weight,
epsilon=self.eps,
quant_dtype=torch.int8,
residual=residual,
update_input=update_hd
)
new_residual = residual
input_quant_args = [i_q, _scales]
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output = self.quant_method.apply(self, input_, bias, input_quant_args)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, new_residual, output_bias, input_quant_args
else:
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output = self.quant_method.apply(self, input_, bias)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
def extra_repr(self) -> str: def extra_repr(self) -> str:
s = f"in_features={self.input_size}" s = f"in_features={self.input_size}"
...@@ -436,6 +482,7 @@ class ColumnParallelLinear(LinearBase): ...@@ -436,6 +482,7 @@ class ColumnParallelLinear(LinearBase):
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
output_sizes: Optional[list[int]] = None, output_sizes: Optional[list[int]] = None,
eps: Optional[float] = 1e-6,
prefix: str = "", prefix: str = "",
*, *,
return_bias: bool = True, return_bias: bool = True,
...@@ -459,7 +506,7 @@ class ColumnParallelLinear(LinearBase): ...@@ -459,7 +506,7 @@ class ColumnParallelLinear(LinearBase):
quant_config, quant_config,
prefix, prefix,
return_bias=return_bias) return_bias=return_bias)
self.eps = eps
self.gather_output = gather_output self.gather_output = gather_output
if output_sizes is None: if output_sizes is None:
...@@ -543,22 +590,49 @@ class ColumnParallelLinear(LinearBase): ...@@ -543,22 +590,49 @@ class ColumnParallelLinear(LinearBase):
param.load_column_parallel_weight(loaded_weight=loaded_weight) param.load_column_parallel_weight(loaded_weight=loaded_weight)
def forward( def forward(
self, input_ self, input_,
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
update_hd: Optional[bool] = True
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
bias = self.bias if not self.skip_bias_add else None if envs.USE_FUSED_RMS_QUANT and rms_weight is not None:
input_quant_args = None
assert rms_weight is not None
i_q, _scales = lm_faster_rmsquant(input=input_,
rms_weight=rms_weight,
epsilon=self.eps,
quant_dtype=torch.int8,
residual=residual,
update_input=update_hd)
new_residual = residual
input_quant_args = [i_q, _scales]
bias = self.bias if not self.skip_bias_add else None
# Matrix multiply. assert self.quant_method is not None
assert self.quant_method is not None output_parallel = self.quant_method.apply(self, input_, bias, input_quant_args)
output_parallel = self.quant_method.apply(self, input_, bias) if self.gather_output:
if self.gather_output: output = tensor_model_parallel_all_gather(output_parallel)
# All-gather across the partitions. else:
output = tensor_model_parallel_all_gather(output_parallel) output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, new_residual, output_bias
else: else:
output = output_parallel bias = self.bias if not self.skip_bias_add else None
output_bias = self.bias if self.skip_bias_add else None # Matrix multiply.
if not self.return_bias: assert self.quant_method is not None
return output output_parallel = self.quant_method.apply(self, input_, bias)
return output, output_bias if self.gather_output:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
def extra_repr(self) -> str: def extra_repr(self) -> str:
s = f"in_features={self.input_size}" s = f"in_features={self.input_size}"
...@@ -593,6 +667,54 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -593,6 +667,54 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
return_bias: If true, return bias together with outputs in forward pass. return_bias: If true, return bias together with outputs in forward pass.
""" """
def forward(
    self, input_,
    rms_weight: Optional[torch.Tensor] = None,
    residual: Optional[torch.Tensor] = None,
    update_hd: Optional[bool] = True
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
    """Run the merged column-parallel projection on ``input_``.

    Two paths exist:

    * Fused path -- taken when ``envs.USE_FUSED_RMS_QUANT`` is set and
      ``rms_weight`` is given. ``lm_faster_rmsquant`` is called on the
      input first and its two results are handed to the quant method as
      ``input_quant_args``.
    * Plain path -- taken otherwise (flag unset, or no ``rms_weight``).
      The quant method is applied to ``input_`` directly.

    Args:
        input_: Tensor passed to the fused op and to the quant method.
        rms_weight: Weight forwarded to ``lm_faster_rmsquant``; ``None``
            selects the plain path.
        residual: Residual forwarded to ``lm_faster_rmsquant``; asserted
            to be non-``None`` on the fused path.
        update_hd: Forwarded to ``lm_faster_rmsquant`` as ``update_input``.

    Returns:
        ``output`` alone when ``self.return_bias`` is false. Otherwise
        ``(output, residual, output_bias)`` on the fused path and
        ``(output, output_bias)`` on the plain path.
    """
    fused = envs.USE_FUSED_RMS_QUANT and rms_weight is not None

    # Extra positional arguments for the quant method; stays empty on the
    # plain path so ``apply`` is called with exactly three arguments there.
    extra_apply_args = ()
    if fused:
        assert residual is not None and rms_weight is not None
        quantized, scales = lm_faster_rmsquant(input=input_,
                                               rms_weight=rms_weight,
                                               epsilon=self.eps,
                                               quant_dtype=torch.int8,
                                               residual=residual,
                                               update_input=update_hd)
        extra_apply_args = ([quantized, scales],)

    # The bias goes into the matmul only when it is not returned separately.
    matmul_bias = None if self.skip_bias_add else self.bias
    assert self.quant_method is not None
    local_output = self.quant_method.apply(self, input_, matmul_bias,
                                           *extra_apply_args)

    # All-gather across the partitions when requested.
    if self.gather_output:
        output = tensor_model_parallel_all_gather(local_output)
    else:
        output = local_output

    output_bias = self.bias if self.skip_bias_add else None
    if not self.return_bias:
        return output
    if fused:
        # The caller's residual object is handed back as the new residual.
        return output, residual, output_bias
    return output, output_bias
def __init__( def __init__(
self, self,
input_size: int, input_size: int,
...@@ -602,10 +724,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -602,10 +724,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
eps: Optional[float] = 1e-6,
prefix: str = "", prefix: str = "",
*, *,
return_bias: bool = True, return_bias: bool = True,
): ):
self.eps = eps
self.output_sizes = output_sizes self.output_sizes = output_sizes
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes) assert all(output_size % tp_size == 0 for output_size in output_sizes)
...@@ -856,7 +980,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -856,7 +980,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_offset=shard_offset, shard_offset=shard_offset,
shard_size=shard_size) shard_size=shard_size)
class QKVParallelLinear(ColumnParallelLinear): class QKVParallelLinear(ColumnParallelLinear):
"""Linear layers for the attention's QKV transformation. """Linear layers for the attention's QKV transformation.
......
...@@ -21,6 +21,7 @@ from vllm.utils import W8a8GetCacheJSON ...@@ -21,6 +21,7 @@ from vllm.utils import W8a8GetCacheJSON
import os import os
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm import envs
W8A8_TRITONJSON=W8a8GetCacheJSON() W8A8_TRITONJSON=W8a8GetCacheJSON()
...@@ -153,8 +154,13 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase): ...@@ -153,8 +154,13 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None,
input_quant_args: Optional[list[torch.Tensor]] = None
): ):
x_q, x_scale = per_token_quant_int8(x) if envs.USE_FUSED_RMS_QUANT and input_quant_args is not None:
assert len(input_quant_args) == 2
x_q, x_scale = input_quant_args
else:
x_q, x_scale = per_token_quant_int8(x)
if self.w8a8_strategy==1: if self.w8a8_strategy==1:
m=x_q.shape[0] m=x_q.shape[0]
......
...@@ -94,11 +94,21 @@ class DeepseekV2MLP(nn.Module): ...@@ -94,11 +94,21 @@ class DeepseekV2MLP(nn.Module):
"Only silu is supported for now.") "Only silu is supported for now.")
self.act_fn = SiluAndMul() self.act_fn = SiluAndMul()
def forward(self, x): def forward(self, x,
gate_up, _ = self.gate_up_proj(x) rms_weight: Optional[torch.Tensor] = None,
x = self.act_fn(gate_up) residual: Optional[torch.Tensor] = None,
x, _ = self.down_proj(x) update_hd: Optional[bool] = False
return x ):
if envs.USE_FUSED_RMS_QUANT:
gate_up, new_resi, _ = self.gate_up_proj(x, rms_weight, residual, update_hd=update_hd)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x, new_resi
else:
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class DeepseekV2MoE(nn.Module): class DeepseekV2MoE(nn.Module):
...@@ -185,11 +195,17 @@ class DeepseekV2MoE(nn.Module): ...@@ -185,11 +195,17 @@ class DeepseekV2MoE(nn.Module):
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
self.tbo_all_reduce = tbo_all_reduce self.tbo_all_reduce = tbo_all_reduce
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor,
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None
) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
if self.n_shared_experts is not None: if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states) if envs.USE_FUSED_RMS_QUANT:
shared_output, new_resi = self.shared_experts(hidden_states, rms_weight, residual, update_hd=True)
else:
shared_output = self.shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
...@@ -219,8 +235,10 @@ class DeepseekV2MoE(nn.Module): ...@@ -219,8 +235,10 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states = ( final_hidden_states = (
self.experts.maybe_all_reduce_tensor_model_parallel( self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states)) final_hidden_states))
if envs.USE_FUSED_RMS_QUANT:
return final_hidden_states.view(num_tokens, hidden_dim) return final_hidden_states.view(num_tokens, hidden_dim), new_resi
else:
return final_hidden_states.view(num_tokens, hidden_dim)
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
...@@ -421,19 +439,36 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -421,19 +439,36 @@ class DeepseekV2MLAAttention(nn.Module):
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
if self.q_lora_rank is not None: if self.q_lora_rank is not None:
self.q_a_proj = ReplicatedLinear(self.hidden_size, if envs.USE_FUSED_RMS_QUANT:
self.q_a_proj = ReplicatedLinear(self.hidden_size,
self.q_lora_rank, self.q_lora_rank,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
eps=config.rms_norm_eps,
prefix=f"{prefix}.q_a_proj") prefix=f"{prefix}.q_a_proj")
self.q_a_layernorm = RMSNorm(self.q_lora_rank, self.q_b_proj = ColumnParallelLinear(q_lora_rank,
eps=config.rms_norm_eps) self.num_heads *
self.q_b_proj = ColumnParallelLinear(q_lora_rank, self.qk_head_dim,
bias=False,
quant_config=quant_config,
eps=config.rms_norm_eps,
prefix=f"{prefix}.q_b_proj")
else:
self.q_a_proj = ReplicatedLinear(self.hidden_size,
self.q_lora_rank,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_a_proj")
self.q_b_proj = ColumnParallelLinear(q_lora_rank,
self.num_heads * self.num_heads *
self.qk_head_dim, self.qk_head_dim,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.q_b_proj") prefix=f"{prefix}.q_b_proj")
self.q_a_layernorm = RMSNorm(self.q_lora_rank,
eps=config.rms_norm_eps)
else: else:
self.q_proj = ColumnParallelLinear(self.hidden_size, self.q_proj = ColumnParallelLinear(self.hidden_size,
self.num_heads * self.num_heads *
...@@ -508,31 +543,60 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -508,31 +543,60 @@ class DeepseekV2MLAAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None
) -> torch.Tensor: ) -> torch.Tensor:
if self.q_lora_rank is not None: if envs.USE_FUSED_RMS_QUANT and rms_weight is not None:
q_c = self.q_a_proj(hidden_states)[0] if self.q_lora_rank is not None:
q_c = self.q_a_layernorm(q_c) q_c, new_residual, _, input_quant_args = self.q_a_proj(hidden_states, rms_weight=rms_weight, residual=residual, update_hd=False)
q = self.q_b_proj(q_c)[0] q, _, _ = self.q_b_proj(q_c, rms_weight=self.q_a_layernorm.weight.data, residual=None, update_hd=False)
else:
q = self.q_proj(hidden_states)[0]
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states, quant_args=input_quant_args, update_hd=False)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
q = q.view(-1, self.num_local_heads, self.qk_head_dim)
# Add head dim of 1 to k_pe
k_pe = k_pe.unsqueeze(1)
q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
positions, q[..., self.qk_nope_head_dim:], k_pe)
attn_out = self.mla_attn(
q,
kv_c_normed,
k_pe,
output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim))
return self.o_proj(attn_out)[0], new_residual
else: else:
q = self.q_proj(hidden_states)[0] if self.q_lora_rank is not None:
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split( q_c = self.q_a_proj(hidden_states)[0]
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) q_c = self.q_a_layernorm(q_c)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) q = self.q_b_proj(q_c)[0]
else:
q = self.q_proj(hidden_states)[0]
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
q = q.view(-1, self.num_local_heads, self.qk_head_dim) q = q.view(-1, self.num_local_heads, self.qk_head_dim)
# Add head dim of 1 to k_pe # Add head dim of 1 to k_pe
k_pe = k_pe.unsqueeze(1) k_pe = k_pe.unsqueeze(1)
q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb( q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
positions, q[..., self.qk_nope_head_dim:], k_pe) positions, q[..., self.qk_nope_head_dim:], k_pe)
attn_out = self.mla_attn( attn_out = self.mla_attn(
q, q,
kv_c_normed, kv_c_normed,
k_pe, k_pe,
output_shape=(hidden_states.shape[0], output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim)) self.num_local_heads * self.v_head_dim))
return self.o_proj(attn_out)[0] return self.o_proj(attn_out)[0]
class DeepseekV2DecoderLayer(nn.Module): class DeepseekV2DecoderLayer(nn.Module):
...@@ -607,47 +671,90 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -607,47 +671,90 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
# Self Attention if envs.USE_FUSED_RMS_QUANT:
# Fix residual FP16 overflow
# Fix residual FP16 overflow residual_fix_overflow = False
residual_fix_overflow = False
if residual is None: assert self.input_layernorm.has_weight is True
residual = hidden_states if residual is None:
hidden_states = self.input_layernorm(hidden_states) residual = hidden_states
residual_fix_overflow = True hidden_states, _ = self.self_attn(
positions = positions,
hidden_states = hidden_states,
rms_weight = self.input_layernorm.weight.data,
residual = None
)
residual_fix_overflow = True
else:
hidden_states, new_residual = self.self_attn(
positions = positions,
hidden_states = hidden_states,
rms_weight = self.input_layernorm.weight.data,
residual = residual
)
residual = new_residual
if hidden_states.dtype == torch.float16 and not self.dpsk_fp16_quick:
# Fix FP16 overflow
# We scale both hidden_states and residual before
# rmsnorm, and the rmsnorm result is not affected by the scale.
hidden_states *= 1. / self.routed_scaling_factor
if self.layer_idx == 0 or residual_fix_overflow:
# The residual is shared by all layers, we only scale it on
# first layer.
residual *= 1. / self.routed_scaling_factor
hidden_states, new_resi = self.mlp(hidden_states, self.post_attention_layernorm.weight.data, residual)
if isinstance(self.mlp,
DeepseekV2MLP) and hidden_states.dtype == torch.float16 and not self.dpsk_fp16_quick:
# Fix FP16 overflow
# Scaling the DeepseekV2MLP output, it is the input of
# input_layernorm of next decoder layer.
# The scaling of DeepseekV2MOE output would be done in the forward
# of DeepseekV2MOE
hidden_states *= 1. / self.routed_scaling_factor
return hidden_states, new_resi
else: else:
hidden_states, residual = self.input_layernorm( # Self Attention
# Fix residual FP16 overflow
residual_fix_overflow = False
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
residual_fix_overflow = True
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
)
if hidden_states.dtype == torch.float16 and not self.dpsk_fp16_quick:
# Fix FP16 overflow
# We scale both hidden_states and residual before
# rmsnorm, and rmsnorm result would not affect by scale.
hidden_states *= 1. / self.routed_scaling_factor
if self.layer_idx == 0 or residual_fix_overflow:
# The residual is shared by all layers, we only scale it on
# first layer.
residual *= 1. / self.routed_scaling_factor
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual) hidden_states, residual)
hidden_states = self.self_attn( hidden_states = self.mlp(hidden_states)
positions=positions,
hidden_states=hidden_states,
)
if hidden_states.dtype == torch.float16 and not self.dpsk_fp16_quick: if isinstance(self.mlp,
# Fix FP16 overflow DeepseekV2MLP) and hidden_states.dtype == torch.float16 and not self.dpsk_fp16_quick:
# We scale both hidden_states and residual before # Fix FP16 overflow
# rmsnorm, and rmsnorm result would not affect by scale. # Scaling the DeepseekV2MLP output, it is the input of
hidden_states *= 1. / self.routed_scaling_factor # input_layernorm of next decoder layer.
if self.layer_idx == 0 or residual_fix_overflow: # The scaling of DeepseekV2MOE output would be done in the forward
# The residual is shared by all layers, we only scale it on # of DeepseekV2MOE
# first layer. hidden_states *= 1. / self.routed_scaling_factor
residual *= 1. / self.routed_scaling_factor
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
if isinstance(self.mlp,
DeepseekV2MLP) and hidden_states.dtype == torch.float16 and not self.dpsk_fp16_quick:
# Fix FP16 overflow
# Scaling the DeepseekV2MLP output, it is the input of
# input_layernorm of next decoder layer.
# The scaling of DeepseekV2MOE output would be done in the forward
# of DeepseekV2MOE
hidden_states *= 1. / self.routed_scaling_factor
return hidden_states, residual return hidden_states, residual
@support_torch_compile @support_torch_compile
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment