"platforms/vscode:/vscode.git/clone" did not exist on "043377fd4491c63a9d7b8dcb271985e4a6469c54"
Commit a05d749e authored by wujl5's avatar wujl5 Committed by zhangzbb
Browse files

[BUGFIX] rms_quant融合功能适配DSA

parent 456e8c10
......@@ -271,7 +271,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
bias: torch.Tensor | None = None, **_
) -> torch.Tensor:
if self.use_llama_nn:
# if os.environ['GEMM_PAD'] == '1' and gemm_bank_conf(layer.weight.shape[1] - 32):
......@@ -458,11 +458,15 @@ class ReplicatedLinear(LinearBase):
def forward(
self,
x: torch.Tensor,
*,
iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output = self.quant_method.apply(self, x, bias)
if envs.USE_FUSED_RMS_QUANT and iqis is not None and iqis[0] is not None:
output = self.quant_method.apply(self, x, bias, input_quant_args=iqis)
else:
output = self.quant_method.apply(self, x, bias)
if not self.return_bias:
return output
......
......@@ -177,9 +177,10 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
)
if self.indexer and self.is_sparse:
_topk_indices = self.indexer(
hidden_states, q_c, positions, self.indexer_rope_emb
)
if envs.USE_FUSED_RMS_QUANT and iqis is not None:
_topk_indices = self.indexer(hidden_states, q_c, positions, self.indexer_rope_emb, iqis=iqis)
else:
_topk_indices = self.indexer(hidden_states, q_c, positions, self.indexer_rope_emb)
if llama_4_scaling is not None:
q *= llama_4_scaling
......
......@@ -730,15 +730,18 @@ class Indexer(nn.Module):
)
def forward(
self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, rotary_emb
self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, rotary_emb,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor:
q, _ = self.wq_b(qr)
q = q.view(-1, self.n_head, self.head_dim)
q_pe, q_nope = torch.split(
q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1
)
k, _ = self.wk(hidden_states)
if envs.USE_FUSED_RMS_QUANT and self.wk.weight.dtype == torch.int8 and iqis is not None:
k, _ = self.wk(hidden_states, iqis=iqis)
else:
k, _ = self.wk(hidden_states)
k = self.k_norm(k)
k_pe, k_nope = torch.split(
k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1
......@@ -770,7 +773,10 @@ class Indexer(nn.Module):
else:
q_fp8 = q
weights, _ = self.weights_proj(hidden_states)
if envs.USE_FUSED_RMS_QUANT and self.weights_proj.weight.dtype == torch.int8 and iqis is not None:
weights, _ = self.weights_proj(hidden_states, iqis=iqis)
else:
weights, _ = self.weights_proj(hidden_states)
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
weights = (
weights.unsqueeze(-1) * q_scale * self.softmax_scale * self.n_head**-0.5
......@@ -1073,19 +1079,21 @@ class DeepseekV2DecoderLayer(nn.Module):
# Fix residual FP16 overflow
residual_fix_overflow = False
assert self.input_layernorm.has_weight is True
# DSA should set update_input True
_dsa_flag = hasattr(self.self_attn, "indexer") and self.self_attn.indexer is not None
if residual is None:
residual = hidden_states.clone()
i_q, i_s, _ = self.input_layernorm(x=hidden_states,
residual=None,
quant_dtype=torch.int8,
update_input=False
update_input=_dsa_flag
)
residual_fix_overflow = True
else:
i_q, i_s, residual = self.input_layernorm(x=hidden_states,
residual=residual,
quant_dtype=torch.int8,
update_input=False
update_input=_dsa_flag
)
attn_kwargs = {
"positions": positions,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment