Commit fc345b74 authored by wanghl6
Browse files

恢复误删代码

parent 153002ad
......@@ -883,6 +883,17 @@ class Indexer(nn.Module):
bias_k,
eps
)
enable_lightly_cp = get_forward_context().enable_lightly_cp
if enable_lightly_cp:
k = tensor_model_parallel_all_gather(
k.contiguous(), 0
)
gather_indexes_tensor = get_forward_context().gather_indexes_tensor
enable_lightly_cplb = get_forward_context().enable_lightly_cplb
if enable_lightly_cplb and gather_indexes_tensor is not None:
k = torch.index_select(k, 0, gather_indexes_tensor)
if current_platform.is_rocm() and torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] != "gfx938":
q_fp8 = q
q_scale = None
......@@ -908,6 +919,16 @@ class Indexer(nn.Module):
# `k_pe` is [num_tokens, 1, rope_dim] (MQA).
k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1)
enable_lightly_cp = get_forward_context().enable_lightly_cp
if enable_lightly_cp:
k = tensor_model_parallel_all_gather(
k.contiguous(), 0
)
gather_indexes_tensor = get_forward_context().gather_indexes_tensor
enable_lightly_cplb = get_forward_context().enable_lightly_cplb
if enable_lightly_cplb and gather_indexes_tensor is not None:
k = torch.index_select(k, 0, gather_indexes_tensor)
# we only quant q here since k quant is fused with cache insertion
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
q = q.view(-1, self.head_dim)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment