Unverified commit 82dec1f7, authored by Lianmin Zheng, committed by GitHub

Remove redundant type conversion (#4513)

parent 5f9b2c62
@@ -21,7 +21,8 @@ concurrency:

 jobs:
   accuracy-test-1-gpu-amd:
-    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && github.event.pull_request.draft == false
+    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
+      github.event.pull_request.draft == false
     runs-on: linux-mi300-gpu-1
     steps:
       - name: Checkout code
@@ -60,7 +61,8 @@ jobs:
       docker exec -w /sglang-checkout/test/srt ci_sglang python3 models/test_qwen_models.py

   mla-test-1-gpu-amd:
-    if: github.event.pull_request.head.repo.fork == false && github.event.pull_request.draft == false
+    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
+      github.event.pull_request.draft == false
     runs-on: linux-mi300-gpu-1
     steps:
       - name: Checkout code
@@ -97,6 +99,7 @@ jobs:
       docker exec -w /sglang-checkout/test/srt ci_sglang python3 test_mla.py

   finish:
+    if: always()
     needs: [
       accuracy-test-1-gpu-amd, mla-test-1-gpu-amd
     ]
......
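Two behavioral notes on this workflow hunk: the `mla-test-1-gpu-amd` job now uses the same trigger condition as the accuracy job (running for pull requests from forks rather than only non-fork branches), and the `finish` gate job gains `if: always()`. The latter matters because a job with `needs:` is implicitly gated on all of its dependencies succeeding. A rough Python sketch of that gating semantics (a hypothetical helper modeling GitHub Actions behavior, not Actions code itself):

```python
# Hypothetical model of a gate job's run condition. By default a job
# that `needs` other jobs runs only when all of them succeeded; with
# `if: always()` it runs regardless, so `finish` can report a status
# even when accuracy-test or mla-test fails or is skipped.
def should_run_finish(needed_results: list[str], condition: str = "success()") -> bool:
    if condition == "always()":
        return True
    return all(result == "success" for result in needed_results)

assert should_run_finish(["success", "failure"], "always()")
assert not should_run_finish(["success", "failure"])  # implicit success()
```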
@@ -1008,7 +1008,7 @@ class FlashInferMultiStepDraftBackend:
         global_override_indptr_cpu = None

     def init_forward_metadata(self, forward_batch: ForwardBatch):
-        kv_indices = torch.zeros(
+        kv_indices = torch.empty(
             (
                 self.speculative_num_steps,
                 forward_batch.batch_size * self.topk * self.max_context_len,
......
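The `torch.zeros` to `torch.empty` change here (repeated throughout the Triton backend below) drops the device-side fill that `torch.zeros` performs on top of allocation. That is only safe when every element later read is written first, which is the contract in these backends: the index-building kernels populate the slices that downstream code consumes. A minimal sketch of the distinction (illustrative only, not from the PR):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# torch.zeros allocates and then fills with zeros (an extra kernel on
# GPU); torch.empty only allocates and leaves the contents undefined.
zeroed = torch.zeros((8, 4096), dtype=torch.int32, device=device)
uninit = torch.empty((8, 4096), dtype=torch.int32, device=device)

# Reading `uninit` before writing is undefined, so the swap is only
# correct because a kernel writes every element that is later read;
# this fill_ stands in for those index-building writes.
uninit.fill_(0)
```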
@@ -84,7 +84,7 @@ class TritonAttnBackend(AttentionBackend):
         if spec_info is None:
             kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
-            kv_indices = torch.zeros(
+            kv_indices = torch.empty(
                 forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
             )
             create_flashinfer_kv_indices_triton[(bs,)](
@@ -100,7 +100,7 @@ class TritonAttnBackend(AttentionBackend):
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
             bs = kv_indptr.shape[0] - 1

-        attn_logits = torch.zeros(
+        attn_logits = torch.empty(
             (
                 bs,
                 self.num_head,
@@ -127,7 +127,7 @@ class TritonAttnBackend(AttentionBackend):
             # Different with flashinfer kv_indptr and kv_indices construction
             kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
-            kv_indices = torch.zeros(
+            kv_indices = torch.empty(
                 kv_indptr[-1], dtype=torch.int32, device=self.device
             )
             create_flashinfer_kv_indices_triton[(bs,)](
@@ -166,7 +166,7 @@ class TritonAttnBackend(AttentionBackend):
                 forward_batch.extend_prefix_lens, dim=0
             )
             kv_indptr = kv_indptr[: bs + 1]
-            kv_indices = torch.zeros(
+            kv_indices = torch.empty(
                 forward_batch.extend_prefix_lens.sum().item(),
                 dtype=torch.int32,
                 device=self.device,
@@ -531,7 +531,7 @@ class TritonMultiStepDraftBackend:
         call_fn(i, forward_batch)

     def init_forward_metadata(self, forward_batch: ForwardBatch):
-        kv_indices = torch.zeros(
+        kv_indices = torch.empty(
             (
                 self.speculative_num_steps,
                 forward_batch.batch_size * self.topk * self.max_context_len,
......
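A related reason these buffers can stay uninitialized: they follow a CSR-style layout where `kv_indptr[i] : kv_indptr[i + 1]` delimits the valid slice for request i, so nothing past `kv_indptr[-1]` is ever read. A small self-contained sketch of that access pattern (simplified names, CPU tensors for illustration):

```python
import torch

seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32)
kv_indptr = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)

# Allocate without zero-filling; only indptr-delimited slices are read.
kv_indices = torch.empty(int(kv_indptr[-1]), dtype=torch.int32)
for i in range(len(seq_lens)):
    start, end = int(kv_indptr[i]), int(kv_indptr[i + 1])
    kv_indices[start:end] = torch.arange(start, end, dtype=torch.int32)

print(kv_indices[kv_indptr[1] : kv_indptr[2]])  # request 1's entries
```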
@@ -168,7 +168,7 @@ class Sampler(nn.Module):
                 group=self.tp_sync_group,
             )

-        return batch_next_token_ids.to(torch.int32)
+        return batch_next_token_ids

     def _apply_custom_logit_processor(
         self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo
......
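This hunk is the redundant conversion named in the commit title: PyTorch's sampling ops already return int64 token ids, so the old `.to(torch.int32)` introduced a down-cast that downstream consumers (such as the future-token map below) then had to widen back. A minimal sketch of the round trip being removed (simplified, not the actual Sampler code):

```python
import torch

probs = torch.softmax(torch.randn(4, 32000), dim=-1)

# torch.multinomial (and torch.argmax) already yield int64 ids.
ids = torch.multinomial(probs, num_samples=1).squeeze(-1)
assert ids.dtype == torch.int64

# The removed pattern: narrowing here just to widen again later.
narrowed = ids.to(torch.int32)      # extra cast kernel + copy
widened = narrowed.to(torch.int64)  # undone downstream anyway
assert torch.equal(ids, widened)
```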
@@ -69,7 +69,7 @@ class TpModelWorkerClient:
         self.future_token_ids_ct = 0
         self.future_token_ids_limit = self.max_running_requests * 3
         self.future_token_ids_map = torch.empty(
-            (self.max_running_requests * 5,), dtype=torch.int32, device=self.device
+            (self.max_running_requests * 5,), dtype=torch.int64, device=self.device
         )

         # Launch threads
......
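With the sampler now returning int64 ids unchanged, the future-token map is widened to int64 so that stores and lookups never cast implicitly. A sketch of the store/resolve pattern this map supports (a hypothetical simplification of the worker's future-token mechanism):

```python
import torch

max_running_requests = 8
future_map = torch.empty((max_running_requests * 5,), dtype=torch.int64)

# The worker thread stores sampled ids into reserved slots ...
slots = torch.tensor([0, 1, 2])
sampled = torch.tensor([101, 2047, 7], dtype=torch.int64)
future_map[slots] = sampled  # matching dtypes: no implicit narrowing

# ... and the next batch resolves its placeholder ids by lookup.
resolved = future_map[slots]
assert resolved.dtype == torch.int64
```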
@@ -44,6 +44,9 @@ class TestUpdateWeightsFromTensor(unittest.TestCase):

     def test_update_weights_from_tensor(self):
         tp_sizes = [1, 2]
         for tp_size in tp_sizes:
+            if torch.cuda.device_count() < tp_size:
+                continue
+
             with self.subTest(tp_size=tp_size):
                 test_update_weights_from_tensor(tp_size)
......
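The added guard makes the tp_size=2 case a no-op on single-GPU machines instead of a failure. An alternative with the same intent that reports the skip explicitly through unittest (a sketch, not the PR's code; `run_update_weights` is a stand-in for the module-level helper the real test calls):

```python
import unittest

import torch

def run_update_weights(tp_size: int) -> None:
    """Stand-in for the module-level test_update_weights_from_tensor."""

class TestUpdateWeightsFromTensor(unittest.TestCase):
    def test_update_weights_from_tensor(self):
        for tp_size in [1, 2]:
            with self.subTest(tp_size=tp_size):
                if torch.cuda.device_count() < tp_size:
                    self.skipTest(f"requires {tp_size} GPUs")
                run_update_weights(tp_size)
```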