Unverified commit 11b23ae9 authored by Ke Bao, committed by GitHub

Remove extra copy in deepseek forward absorb (#5578)


Co-authored-by: saienduri <saimanas.enduri@amd.com>
parent b9c87e78
@@ -38,12 +38,12 @@ jobs:
           else
             DEVICE_FLAG="--device /dev/dri"
           fi
-          docker pull lmsysorg/sglang:v0.4.5-rocm630
+          docker pull lmsysorg/sglang:v0.4.5.post2-rocm630
           docker run -dt --user root --device=/dev/kfd $DEVICE_FLAG \
             -v ${{ github.workspace }}:/sglang-checkout --ipc=host --group-add video \
             --cap-add=SYS_PTRACE -e HF_TOKEN=${HF_TOKEN} --security-opt seccomp=unconfined \
             -w /sglang-checkout --name ci_sglang \
-            lmsysorg/sglang:v0.4.5-rocm630
+            lmsysorg/sglang:v0.4.5.post2-rocm630

       - name: Install dependencies
         run: |
@@ -82,12 +82,12 @@ jobs:
           else
             DEVICE_FLAG="--device /dev/dri"
           fi
-          docker pull lmsysorg/sglang:v0.4.5-rocm630
+          docker pull lmsysorg/sglang:v0.4.5.post2-rocm630
           docker run -dt --user root --device=/dev/kfd $DEVICE_FLAG \
             -v ${{ github.workspace }}:/sglang-checkout --ipc=host --group-add video \
             --cap-add=SYS_PTRACE -e HF_TOKEN=${{ secrets.AMD_HF_TOKEN }} --security-opt seccomp=unconfined \
             -w /sglang-checkout --name ci_sglang \
-            lmsysorg/sglang:v0.4.5-rocm630
+            lmsysorg/sglang:v0.4.5.post2-rocm630

       - name: Install dependencies
         run: |
@@ -120,12 +120,12 @@ jobs:
           else
             DEVICE_FLAG="--device /dev/dri"
           fi
-          docker pull lmsysorg/sglang:v0.4.5-rocm630
+          docker pull lmsysorg/sglang:v0.4.5.post2-rocm630
           docker run -dt --user root --device=/dev/kfd $DEVICE_FLAG \
             -v ${{ github.workspace }}:/sglang-checkout --ipc=host --group-add video \
             --cap-add=SYS_PTRACE -e HF_TOKEN=${HF_TOKEN} --security-opt seccomp=unconfined \
             -w /sglang-checkout --name ci_sglang \
-            lmsysorg/sglang:v0.4.5-rocm630
+            lmsysorg/sglang:v0.4.5.post2-rocm630

       - name: Install dependencies
         run: |
@@ -149,7 +149,7 @@ jobs:
   finish:
     if: always()
     needs: [
-      accuracy-test-1-gpu-amd, mla-test-1-gpu-amd
+      accuracy-test-1-gpu-amd, mla-test-1-gpu-amd, bench-test-2-gpu-amd
     ]
     runs-on: ubuntu-latest
     steps:
@@ -665,6 +665,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """PyTorch-native implementation equivalent to forward()."""
+        dtype = query.dtype
         query_rot = query[..., : self.rotary_dim]
         key_rot = key[..., : self.rotary_dim]
         if self.rotary_dim < self.head_size:
@@ -695,7 +696,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         else:
             query = query_rot
             key = key_rot
-        return query, key
+        return query.to(dtype), key.to(dtype)


 class Llama3RotaryEmbedding(RotaryEmbedding):
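Note on the rotary embedding hunk above: the old absorb path copied q_pe and k_pe into preallocated buffer slices, which implicitly cast them to the buffer's dtype; with the torch.cat construction in the next file there is no such implicit cast, so the PyTorch-native rotary path presumably now records the incoming dtype and casts its outputs back before returning. Below is a minimal sketch of that pattern with toy tensors, not the actual DeepseekScalingRotaryEmbedding; the float32 cos/sin table is an assumption.

import torch

def apply_rope_native(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Record the caller's dtype, do the math in float32 (as a float32 cos/sin
    # cache would force), then cast back so downstream torch.cat sees bf16 again.
    dtype = x.dtype
    x1, x2 = x.float().chunk(2, dim=-1)
    out = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
    return out.to(dtype)

q_pe = torch.randn(4, 2, 8, dtype=torch.bfloat16)        # toy rotary slice of the query
cos, sin = torch.ones(4), torch.zeros(4)                  # float32 stand-in cache
q_nope_out = torch.randn(4, 2, 16, dtype=torch.bfloat16)

q_pe = apply_rope_native(q_pe, cos, sin)
q = torch.cat([q_nope_out, q_pe], dim=-1)                 # stays bfloat16, no silent upcast
print(q.dtype)  # torch.bfloat16

Without the final .to(dtype), the concatenation would promote q to float32 (or fail on older PyTorch), which is what restoring the dtype at the source avoids.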
@@ -682,10 +682,6 @@ class DeepseekV2AttentionMLA(nn.Module):
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
-        q_len = hidden_states.shape[0]
-        q_input = hidden_states.new_empty(
-            q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
-        )
         if self.q_lora_rank is not None:
             q = self.q_a_proj(hidden_states)[0]
             q = self.q_a_layernorm(q)
@@ -729,20 +725,20 @@ class DeepseekV2AttentionMLA(nn.Module):
             )
         else:
             q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
-        q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)
+        q_nope_out = q_nope_out.transpose(0, 1)

         latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
-        v_input = latent_cache[..., : self.kv_lora_rank]
-        v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
-        k_input = latent_cache.unsqueeze(1)
-        k_input[..., : self.kv_lora_rank] = v_input
-        k_pe = k_input[..., self.kv_lora_rank :]
+        k_nope = latent_cache[..., : self.kv_lora_rank]
+        k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)
+        k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)

         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
-        q_input[..., self.kv_lora_rank :] = q_pe
-        k_input[..., self.kv_lora_rank :] = k_pe

-        attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
+        q = torch.cat([q_nope_out, q_pe], dim=-1)
+        k = torch.cat([k_nope, k_pe], dim=-1)
+
+        attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

         if self.use_deep_gemm_bmm:
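For readers skimming the DeepseekV2AttentionMLA hunk: the old forward_absorb preallocated q_input and k_input and filled them slice by slice (including a .contiguous() on the latent slice), while the new code builds the query and key with one torch.cat each and reuses the layernormed latent (k_nope) directly as the value passed to attn_mqa. The sketch below contrasts the two assembly patterns with made-up shapes and stand-ins; it is not the real module, and the layernorm, rotary, and bmm projection are all omitted or replaced by toy tensors.

import torch

tokens, heads, kv_lora_rank, rope_dim = 4, 2, 16, 8

q_nope_out = torch.randn(heads, tokens, kv_lora_rank)         # stand-in for bmm(q_nope, w_kc)
q_pe = torch.randn(tokens, heads, rope_dim)                    # rotary part of the query
latent_cache = torch.randn(tokens, kv_lora_rank + rope_dim)    # stand-in for kv_a_proj_with_mqa output

# Old pattern: preallocate staging buffers and copy each piece into a slice.
q_input = torch.empty(tokens, heads, kv_lora_rank + rope_dim)
q_input[..., :kv_lora_rank] = q_nope_out.transpose(0, 1)
q_input[..., kv_lora_rank:] = q_pe
v_input = latent_cache[..., :kv_lora_rank].contiguous().unsqueeze(1)   # layernorm omitted
k_input = latent_cache.unsqueeze(1).clone()   # clone so the demo doesn't mutate latent_cache
k_input[..., :kv_lora_rank] = v_input

# New pattern: slice the latent once and concatenate directly, no staging buffer.
k_nope = latent_cache[..., :kv_lora_rank].unsqueeze(1)                 # layernorm omitted
k_pe = latent_cache[..., kv_lora_rank:].unsqueeze(1)
q = torch.cat([q_nope_out.transpose(0, 1), q_pe], dim=-1)
k = torch.cat([k_nope, k_pe], dim=-1)

# Both routes produce the same q and k; k_nope also plays the role v_input had before.
assert torch.allclose(q, q_input) and torch.allclose(k, k_input)

The design point, as the commit title says, is dropping the extra copies: the pieces are concatenated once where they are consumed instead of being staged through q_input and k_input first.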