Commit 8ecffc73 authored by Rayyyyy's avatar Rayyyyy
Browse files

Fix sampler.cu cub

parent 843c1822
......@@ -215,7 +215,11 @@ __device__ bool processHistogramStep(
// Compute the prefix sum.
int prefixSum{0}, totalSum{0};
#ifndef USE_ROCM
using Scan = cub::BlockScan<int, kNumThreadsPerBlock>;
#else:
using Scan = hipcub::BlockScan<int, kNumThreadsPerBlock>;
#endif
Scan(smemFinal.histo.scan).ExclusiveSum(binCount, prefixSum, totalSum);
// Update the histogram with the prefix sums.
......@@ -334,13 +338,22 @@ static __device__ void topKPerRowJob(const int* indices, const float* logits,
static constexpr int kNumFinalItemsPerThread =
kNumFinalItems / kNumThreadsPerBlock;
// The class to sort the elements during the final pass.
#ifndef USE_ROCM
using FinalSort = cub::BlockRadixSort<float, kNumThreadsPerBlock,
kNumFinalItemsPerThread, int>;
#else
using FinalSort = hipcub::BlockRadixSort<float, kNumThreadsPerBlock,
kNumFinalItemsPerThread, int>;
#endif
using FinalSortTempStorage =
std::conditional_t<useRadixSort, typename FinalSort::TempStorage, int>;
// The class to compute the inclusive prefix-sum over the histogram.
#ifndef USE_ROCM
using Scan = cub::BlockScan<int, kNumThreadsPerBlock>;
#else
using Scan = hipcub::BlockScan<int, kNumThreadsPerBlock>;
#endif
// The structure to store the final items (for the final pass).
struct FinalItems {
// Shared memory to store the indices for the final pass.
......
......@@ -944,22 +944,22 @@ _MULTIMODAL_EXAMPLE_MODELS = {
min_transformers_version="4.57",
),
"Qwen3_5ForConditionalGeneration": _HfExamplesInfo(
"Qwen/Qwen3.5-9B-Instruct",
os.path.join(models_path_prefix, "Qwen/Qwen3.5-9B-Instruct"),
max_model_len=4096,
min_transformers_version="5.1.0",
),
"Qwen3_5MoeForConditionalGeneration": _HfExamplesInfo(
"Qwen/Qwen3.5-35B-A3B-Instruct",
os.path.join(models_path_prefix, "Qwen/Qwen3.5-35B-A3B-Instruct"),
max_model_len=4096,
min_transformers_version="5.1.0",
),
"Qwen3_5MTP": _HfExamplesInfo(
"Qwen/Qwen3.5-9B-Instruct",
os.path.join(models_path_prefix, "Qwen/Qwen3.5-9B-Instruct"),
speculative_model="Qwen/Qwen3.5-9B-Instruct",
min_transformers_version="5.1.0",
),
"Qwen3_5MoeMTP": _HfExamplesInfo(
"Qwen/Qwen3.5-35B-A3B-Instruct",
os.path.join(models_path_prefix, "Qwen/Qwen3.5-35B-A3B-Instruct"),
speculative_model="Qwen/Qwen3.5-35B-A3B-Instruct",
min_transformers_version="5.1.0",
),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment