Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
82dec1f7
Unverified
Commit
82dec1f7
authored
Mar 17, 2025
by
Lianmin Zheng
Committed by
GitHub
Mar 17, 2025
Browse files
Remove redundant type conversion (#4513)
parent
5f9b2c62
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
16 additions
and
10 deletions
+16
-10
.github/workflows/pr-test-amd.yml
.github/workflows/pr-test-amd.yml
+5
-2
python/sglang/srt/layers/attention/flashinfer_backend.py
python/sglang/srt/layers/attention/flashinfer_backend.py
+1
-1
python/sglang/srt/layers/attention/triton_backend.py
python/sglang/srt/layers/attention/triton_backend.py
+5
-5
python/sglang/srt/layers/sampler.py
python/sglang/srt/layers/sampler.py
+1
-1
python/sglang/srt/managers/tp_worker_overlap_thread.py
python/sglang/srt/managers/tp_worker_overlap_thread.py
+1
-1
test/srt/test_update_weights_from_tensor.py
test/srt/test_update_weights_from_tensor.py
+3
-0
No files found.
.github/workflows/pr-test-amd.yml
View file @
82dec1f7
...
@@ -21,7 +21,8 @@ concurrency:
...
@@ -21,7 +21,8 @@ concurrency:
jobs
:
jobs
:
accuracy-test-1-gpu-amd
:
accuracy-test-1-gpu-amd
:
if
:
(github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && github.event.pull_request.draft ==
false
if
:
(github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
github.event.pull_request.draft ==
false
runs-on
:
linux-mi300-gpu-1
runs-on
:
linux-mi300-gpu-1
steps
:
steps
:
-
name
:
Checkout code
-
name
:
Checkout code
...
@@ -60,7 +61,8 @@ jobs:
...
@@ -60,7 +61,8 @@ jobs:
docker exec -w /sglang-checkout/test/srt ci_sglang python3 models/test_qwen_models.py
docker exec -w /sglang-checkout/test/srt ci_sglang python3 models/test_qwen_models.py
mla-test-1-gpu-amd
:
mla-test-1-gpu-amd
:
if
:
github.event.pull_request.head.repo.fork ==
false
&& github.event.pull_request.draft ==
false
if
:
(github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
github.event.pull_request.draft ==
false
runs-on
:
linux-mi300-gpu-1
runs-on
:
linux-mi300-gpu-1
steps
:
steps
:
-
name
:
Checkout code
-
name
:
Checkout code
...
@@ -97,6 +99,7 @@ jobs:
...
@@ -97,6 +99,7 @@ jobs:
docker exec -w /sglang-checkout/test/srt ci_sglang python3 test_mla.py
docker exec -w /sglang-checkout/test/srt ci_sglang python3 test_mla.py
finish
:
finish
:
if
:
always()
needs
:
[
needs
:
[
accuracy-test-1-gpu-amd
,
mla-test-1-gpu-amd
accuracy-test-1-gpu-amd
,
mla-test-1-gpu-amd
]
]
...
...
python/sglang/srt/layers/attention/flashinfer_backend.py
View file @
82dec1f7
...
@@ -1008,7 +1008,7 @@ class FlashInferMultiStepDraftBackend:
...
@@ -1008,7 +1008,7 @@ class FlashInferMultiStepDraftBackend:
global_override_indptr_cpu
=
None
global_override_indptr_cpu
=
None
def
init_forward_metadata
(
self
,
forward_batch
:
ForwardBatch
):
def
init_forward_metadata
(
self
,
forward_batch
:
ForwardBatch
):
kv_indices
=
torch
.
zeros
(
kv_indices
=
torch
.
empty
(
(
(
self
.
speculative_num_steps
,
self
.
speculative_num_steps
,
forward_batch
.
batch_size
*
self
.
topk
*
self
.
max_context_len
,
forward_batch
.
batch_size
*
self
.
topk
*
self
.
max_context_len
,
...
...
python/sglang/srt/layers/attention/triton_backend.py
View file @
82dec1f7
...
@@ -84,7 +84,7 @@ class TritonAttnBackend(AttentionBackend):
...
@@ -84,7 +84,7 @@ class TritonAttnBackend(AttentionBackend):
if
spec_info
is
None
:
if
spec_info
is
None
:
kv_indptr
[
1
:
bs
+
1
]
=
torch
.
cumsum
(
forward_batch
.
seq_lens
,
dim
=
0
)
kv_indptr
[
1
:
bs
+
1
]
=
torch
.
cumsum
(
forward_batch
.
seq_lens
,
dim
=
0
)
kv_indptr
=
kv_indptr
[:
bs
+
1
]
kv_indptr
=
kv_indptr
[:
bs
+
1
]
kv_indices
=
torch
.
zeros
(
kv_indices
=
torch
.
empty
(
forward_batch
.
seq_lens_sum
,
dtype
=
torch
.
int32
,
device
=
self
.
device
forward_batch
.
seq_lens_sum
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
)
create_flashinfer_kv_indices_triton
[(
bs
,)](
create_flashinfer_kv_indices_triton
[(
bs
,)](
...
@@ -100,7 +100,7 @@ class TritonAttnBackend(AttentionBackend):
...
@@ -100,7 +100,7 @@ class TritonAttnBackend(AttentionBackend):
kv_indptr
,
kv_indices
=
spec_info
.
kv_indptr
,
spec_info
.
kv_indices
kv_indptr
,
kv_indices
=
spec_info
.
kv_indptr
,
spec_info
.
kv_indices
bs
=
kv_indptr
.
shape
[
0
]
-
1
bs
=
kv_indptr
.
shape
[
0
]
-
1
attn_logits
=
torch
.
zeros
(
attn_logits
=
torch
.
empty
(
(
(
bs
,
bs
,
self
.
num_head
,
self
.
num_head
,
...
@@ -127,7 +127,7 @@ class TritonAttnBackend(AttentionBackend):
...
@@ -127,7 +127,7 @@ class TritonAttnBackend(AttentionBackend):
# Different with flashinfer kv_indptr and kv_indices construction
# Different with flashinfer kv_indptr and kv_indices construction
kv_indptr
[
1
:
bs
+
1
]
=
torch
.
cumsum
(
forward_batch
.
seq_lens
,
dim
=
0
)
kv_indptr
[
1
:
bs
+
1
]
=
torch
.
cumsum
(
forward_batch
.
seq_lens
,
dim
=
0
)
kv_indptr
=
kv_indptr
[:
bs
+
1
]
kv_indptr
=
kv_indptr
[:
bs
+
1
]
kv_indices
=
torch
.
zeros
(
kv_indices
=
torch
.
empty
(
kv_indptr
[
-
1
],
dtype
=
torch
.
int32
,
device
=
self
.
device
kv_indptr
[
-
1
],
dtype
=
torch
.
int32
,
device
=
self
.
device
)
)
create_flashinfer_kv_indices_triton
[(
bs
,)](
create_flashinfer_kv_indices_triton
[(
bs
,)](
...
@@ -166,7 +166,7 @@ class TritonAttnBackend(AttentionBackend):
...
@@ -166,7 +166,7 @@ class TritonAttnBackend(AttentionBackend):
forward_batch
.
extend_prefix_lens
,
dim
=
0
forward_batch
.
extend_prefix_lens
,
dim
=
0
)
)
kv_indptr
=
kv_indptr
[:
bs
+
1
]
kv_indptr
=
kv_indptr
[:
bs
+
1
]
kv_indices
=
torch
.
zeros
(
kv_indices
=
torch
.
empty
(
forward_batch
.
extend_prefix_lens
.
sum
().
item
(),
forward_batch
.
extend_prefix_lens
.
sum
().
item
(),
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
self
.
device
,
device
=
self
.
device
,
...
@@ -531,7 +531,7 @@ class TritonMultiStepDraftBackend:
...
@@ -531,7 +531,7 @@ class TritonMultiStepDraftBackend:
call_fn
(
i
,
forward_batch
)
call_fn
(
i
,
forward_batch
)
def
init_forward_metadata
(
self
,
forward_batch
:
ForwardBatch
):
def
init_forward_metadata
(
self
,
forward_batch
:
ForwardBatch
):
kv_indices
=
torch
.
zeros
(
kv_indices
=
torch
.
empty
(
(
(
self
.
speculative_num_steps
,
self
.
speculative_num_steps
,
forward_batch
.
batch_size
*
self
.
topk
*
self
.
max_context_len
,
forward_batch
.
batch_size
*
self
.
topk
*
self
.
max_context_len
,
...
...
python/sglang/srt/layers/sampler.py
View file @
82dec1f7
...
@@ -168,7 +168,7 @@ class Sampler(nn.Module):
...
@@ -168,7 +168,7 @@ class Sampler(nn.Module):
group
=
self
.
tp_sync_group
,
group
=
self
.
tp_sync_group
,
)
)
return
batch_next_token_ids
.
to
(
torch
.
int32
)
return
batch_next_token_ids
def
_apply_custom_logit_processor
(
def
_apply_custom_logit_processor
(
self
,
logits
:
torch
.
Tensor
,
sampling_batch_info
:
SamplingBatchInfo
self
,
logits
:
torch
.
Tensor
,
sampling_batch_info
:
SamplingBatchInfo
...
...
python/sglang/srt/managers/tp_worker_overlap_thread.py
View file @
82dec1f7
...
@@ -69,7 +69,7 @@ class TpModelWorkerClient:
...
@@ -69,7 +69,7 @@ class TpModelWorkerClient:
self
.
future_token_ids_ct
=
0
self
.
future_token_ids_ct
=
0
self
.
future_token_ids_limit
=
self
.
max_running_requests
*
3
self
.
future_token_ids_limit
=
self
.
max_running_requests
*
3
self
.
future_token_ids_map
=
torch
.
empty
(
self
.
future_token_ids_map
=
torch
.
empty
(
(
self
.
max_running_requests
*
5
,),
dtype
=
torch
.
int
32
,
device
=
self
.
device
(
self
.
max_running_requests
*
5
,),
dtype
=
torch
.
int
64
,
device
=
self
.
device
)
)
# Launch threads
# Launch threads
...
...
test/srt/test_update_weights_from_tensor.py
View file @
82dec1f7
...
@@ -44,6 +44,9 @@ class TestUpdateWeightsFromTensor(unittest.TestCase):
...
@@ -44,6 +44,9 @@ class TestUpdateWeightsFromTensor(unittest.TestCase):
def
test_update_weights_from_tensor
(
self
):
def
test_update_weights_from_tensor
(
self
):
tp_sizes
=
[
1
,
2
]
tp_sizes
=
[
1
,
2
]
for
tp_size
in
tp_sizes
:
for
tp_size
in
tp_sizes
:
if
torch
.
cuda
.
device_count
()
<
tp_size
:
continue
with
self
.
subTest
(
tp_size
=
tp_size
):
with
self
.
subTest
(
tp_size
=
tp_size
):
test_update_weights_from_tensor
(
tp_size
)
test_update_weights_from_tensor
(
tp_size
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment