Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8ecffc73
Commit
8ecffc73
authored
Feb 16, 2026
by
Rayyyyy
Browse files
Fix sampler.cu cub
parent
843c1822
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
22 additions
and
9 deletions
+22
-9
csrc/sampler.cu
csrc/sampler.cu
+18
-5
tests/models/registry.py
tests/models/registry.py
+4
-4
No files found.
csrc/sampler.cu
View file @
8ecffc73
...
@@ -215,7 +215,11 @@ __device__ bool processHistogramStep(
...
@@ -215,7 +215,11 @@ __device__ bool processHistogramStep(
// Compute the prefix sum.
// Compute the prefix sum.
int
prefixSum
{
0
},
totalSum
{
0
};
int
prefixSum
{
0
},
totalSum
{
0
};
using
Scan
=
cub
::
BlockScan
<
int
,
kNumThreadsPerBlock
>
;
#ifndef USE_ROCM
using
Scan
=
cub
::
BlockScan
<
int
,
kNumThreadsPerBlock
>
;
#else:
using
Scan
=
hipcub
::
BlockScan
<
int
,
kNumThreadsPerBlock
>
;
#endif
Scan
(
smemFinal
.
histo
.
scan
).
ExclusiveSum
(
binCount
,
prefixSum
,
totalSum
);
Scan
(
smemFinal
.
histo
.
scan
).
ExclusiveSum
(
binCount
,
prefixSum
,
totalSum
);
// Update the histogram with the prefix sums.
// Update the histogram with the prefix sums.
...
@@ -334,13 +338,22 @@ static __device__ void topKPerRowJob(const int* indices, const float* logits,
...
@@ -334,13 +338,22 @@ static __device__ void topKPerRowJob(const int* indices, const float* logits,
static
constexpr
int
kNumFinalItemsPerThread
=
static
constexpr
int
kNumFinalItemsPerThread
=
kNumFinalItems
/
kNumThreadsPerBlock
;
kNumFinalItems
/
kNumThreadsPerBlock
;
// The class to sort the elements during the final pass.
// The class to sort the elements during the final pass.
using
FinalSort
=
cub
::
BlockRadixSort
<
float
,
kNumThreadsPerBlock
,
#ifndef USE_ROCM
kNumFinalItemsPerThread
,
int
>
;
using
FinalSort
=
cub
::
BlockRadixSort
<
float
,
kNumThreadsPerBlock
,
kNumFinalItemsPerThread
,
int
>
;
#else
using
FinalSort
=
hipcub
::
BlockRadixSort
<
float
,
kNumThreadsPerBlock
,
kNumFinalItemsPerThread
,
int
>
;
#endif
using
FinalSortTempStorage
=
using
FinalSortTempStorage
=
std
::
conditional_t
<
useRadixSort
,
typename
FinalSort
::
TempStorage
,
int
>
;
std
::
conditional_t
<
useRadixSort
,
typename
FinalSort
::
TempStorage
,
int
>
;
// The class to compute the inclusive prefix-sum over the histogram.
// The class to compute the inclusive prefix-sum over the histogram.
using
Scan
=
cub
::
BlockScan
<
int
,
kNumThreadsPerBlock
>
;
#ifndef USE_ROCM
using
Scan
=
cub
::
BlockScan
<
int
,
kNumThreadsPerBlock
>
;
#else
using
Scan
=
hipcub
::
BlockScan
<
int
,
kNumThreadsPerBlock
>
;
#endif
// The structure to store the final items (for the final pass).
// The structure to store the final items (for the final pass).
struct
FinalItems
{
struct
FinalItems
{
// Shared memory to store the indices for the final pass.
// Shared memory to store the indices for the final pass.
...
...
tests/models/registry.py
View file @
8ecffc73
...
@@ -944,22 +944,22 @@ _MULTIMODAL_EXAMPLE_MODELS = {
...
@@ -944,22 +944,22 @@ _MULTIMODAL_EXAMPLE_MODELS = {
min_transformers_version
=
"4.57"
,
min_transformers_version
=
"4.57"
,
),
),
"Qwen3_5ForConditionalGeneration"
:
_HfExamplesInfo
(
"Qwen3_5ForConditionalGeneration"
:
_HfExamplesInfo
(
"Qwen/Qwen3.5-9B-Instruct"
,
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen3.5-9B-Instruct"
)
,
max_model_len
=
4096
,
max_model_len
=
4096
,
min_transformers_version
=
"5.1.0"
,
min_transformers_version
=
"5.1.0"
,
),
),
"Qwen3_5MoeForConditionalGeneration"
:
_HfExamplesInfo
(
"Qwen3_5MoeForConditionalGeneration"
:
_HfExamplesInfo
(
"Qwen/Qwen3.5-35B-A3B-Instruct"
,
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen3.5-35B-A3B-Instruct"
)
,
max_model_len
=
4096
,
max_model_len
=
4096
,
min_transformers_version
=
"5.1.0"
,
min_transformers_version
=
"5.1.0"
,
),
),
"Qwen3_5MTP"
:
_HfExamplesInfo
(
"Qwen3_5MTP"
:
_HfExamplesInfo
(
"Qwen/Qwen3.5-9B-Instruct"
,
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen3.5-9B-Instruct"
)
,
speculative_model
=
"Qwen/Qwen3.5-9B-Instruct"
,
speculative_model
=
"Qwen/Qwen3.5-9B-Instruct"
,
min_transformers_version
=
"5.1.0"
,
min_transformers_version
=
"5.1.0"
,
),
),
"Qwen3_5MoeMTP"
:
_HfExamplesInfo
(
"Qwen3_5MoeMTP"
:
_HfExamplesInfo
(
"Qwen/Qwen3.5-35B-A3B-Instruct"
,
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen3.5-35B-A3B-Instruct"
)
,
speculative_model
=
"Qwen/Qwen3.5-35B-A3B-Instruct"
,
speculative_model
=
"Qwen/Qwen3.5-35B-A3B-Instruct"
,
min_transformers_version
=
"5.1.0"
,
min_transformers_version
=
"5.1.0"
,
),
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment