Commit 8ecffc73 authored by Rayyyyy
Browse files

Fix sampler.cu cub

parent 843c1822
...@@ -215,7 +215,11 @@ __device__ bool processHistogramStep( ...@@ -215,7 +215,11 @@ __device__ bool processHistogramStep(
// Compute the prefix sum. // Compute the prefix sum.
int prefixSum{0}, totalSum{0}; int prefixSum{0}, totalSum{0};
using Scan = cub::BlockScan<int, kNumThreadsPerBlock>; #ifndef USE_ROCM
using Scan = cub::BlockScan<int, kNumThreadsPerBlock>;
#else
using Scan = hipcub::BlockScan<int, kNumThreadsPerBlock>;
#endif
Scan(smemFinal.histo.scan).ExclusiveSum(binCount, prefixSum, totalSum); Scan(smemFinal.histo.scan).ExclusiveSum(binCount, prefixSum, totalSum);
// Update the histogram with the prefix sums. // Update the histogram with the prefix sums.
...@@ -334,13 +338,22 @@ static __device__ void topKPerRowJob(const int* indices, const float* logits, ...@@ -334,13 +338,22 @@ static __device__ void topKPerRowJob(const int* indices, const float* logits,
static constexpr int kNumFinalItemsPerThread = static constexpr int kNumFinalItemsPerThread =
kNumFinalItems / kNumThreadsPerBlock; kNumFinalItems / kNumThreadsPerBlock;
// The class to sort the elements during the final pass. // The class to sort the elements during the final pass.
using FinalSort = cub::BlockRadixSort<float, kNumThreadsPerBlock, #ifndef USE_ROCM
kNumFinalItemsPerThread, int>; using FinalSort = cub::BlockRadixSort<float, kNumThreadsPerBlock,
kNumFinalItemsPerThread, int>;
#else
using FinalSort = hipcub::BlockRadixSort<float, kNumThreadsPerBlock,
kNumFinalItemsPerThread, int>;
#endif
using FinalSortTempStorage = using FinalSortTempStorage =
std::conditional_t<useRadixSort, typename FinalSort::TempStorage, int>; std::conditional_t<useRadixSort, typename FinalSort::TempStorage, int>;
// The class to compute the inclusive prefix-sum over the histogram. // The class to compute the inclusive prefix-sum over the histogram.
using Scan = cub::BlockScan<int, kNumThreadsPerBlock>; #ifndef USE_ROCM
using Scan = cub::BlockScan<int, kNumThreadsPerBlock>;
#else
using Scan = hipcub::BlockScan<int, kNumThreadsPerBlock>;
#endif
// The structure to store the final items (for the final pass). // The structure to store the final items (for the final pass).
struct FinalItems { struct FinalItems {
// Shared memory to store the indices for the final pass. // Shared memory to store the indices for the final pass.
......
...@@ -944,22 +944,22 @@ _MULTIMODAL_EXAMPLE_MODELS = { ...@@ -944,22 +944,22 @@ _MULTIMODAL_EXAMPLE_MODELS = {
min_transformers_version="4.57", min_transformers_version="4.57",
), ),
"Qwen3_5ForConditionalGeneration": _HfExamplesInfo( "Qwen3_5ForConditionalGeneration": _HfExamplesInfo(
"Qwen/Qwen3.5-9B-Instruct", os.path.join(models_path_prefix, "Qwen/Qwen3.5-9B-Instruct"),
max_model_len=4096, max_model_len=4096,
min_transformers_version="5.1.0", min_transformers_version="5.1.0",
), ),
"Qwen3_5MoeForConditionalGeneration": _HfExamplesInfo( "Qwen3_5MoeForConditionalGeneration": _HfExamplesInfo(
"Qwen/Qwen3.5-35B-A3B-Instruct", os.path.join(models_path_prefix, "Qwen/Qwen3.5-35B-A3B-Instruct"),
max_model_len=4096, max_model_len=4096,
min_transformers_version="5.1.0", min_transformers_version="5.1.0",
), ),
"Qwen3_5MTP": _HfExamplesInfo( "Qwen3_5MTP": _HfExamplesInfo(
"Qwen/Qwen3.5-9B-Instruct", os.path.join(models_path_prefix, "Qwen/Qwen3.5-9B-Instruct"),
speculative_model="Qwen/Qwen3.5-9B-Instruct", speculative_model="Qwen/Qwen3.5-9B-Instruct",
min_transformers_version="5.1.0", min_transformers_version="5.1.0",
), ),
"Qwen3_5MoeMTP": _HfExamplesInfo( "Qwen3_5MoeMTP": _HfExamplesInfo(
"Qwen/Qwen3.5-35B-A3B-Instruct", os.path.join(models_path_prefix, "Qwen/Qwen3.5-35B-A3B-Instruct"),
speculative_model="Qwen/Qwen3.5-35B-A3B-Instruct", speculative_model="Qwen/Qwen3.5-35B-A3B-Instruct",
min_transformers_version="5.1.0", min_transformers_version="5.1.0",
), ),
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment