Commit eac75188 authored by yuguo
Browse files

[DCU] fix merge

parent 44740c6c
...@@ -168,6 +168,7 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs, ...@@ -168,6 +168,7 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs,
int num_cols, int topk, float coeff, int num_cols, int topk, float coeff,
DataType* aux_loss, float* Const_buf, DataType* aux_loss, float* Const_buf,
cudaStream_t stream) { cudaStream_t stream) {
#ifndef __HIP_PLATFORM_AMD__
if (cuda::sm_arch(cuda::current_device()) >= 90) { if (cuda::sm_arch(cuda::current_device()) >= 90) {
cudaLaunchConfig_t config = {0}; cudaLaunchConfig_t config = {0};
int cluster_size = 8; int cluster_size = 8;
...@@ -193,11 +194,14 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs, ...@@ -193,11 +194,14 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs,
tokens_per_expert, total_num_tokens, num_experts, num_rows, num_cols, topk, tokens_per_expert, total_num_tokens, num_experts, num_rows, num_cols, topk,
coeff, aux_loss, Const_buf); coeff, aux_loss, Const_buf);
} else { } else {
#endif
size_t smem_size = sizeof(CompType) * num_cols; size_t smem_size = sizeof(CompType) * num_cols;
fused_moe_aux_loss_forward_kernel<DataType, IndexType> fused_moe_aux_loss_forward_kernel<DataType, IndexType>
<<<1, 1024, smem_size, stream>>>(probs, tokens_per_expert, total_num_tokens, num_experts, <<<1, 1024, smem_size, stream>>>(probs, tokens_per_expert, total_num_tokens, num_experts,
num_rows, num_cols, topk, coeff, aux_loss, Const_buf); num_rows, num_cols, topk, coeff, aux_loss, Const_buf);
#ifndef __HIP_PLATFORM_AMD__
} }
#endif
} }
void fused_moe_aux_loss_forward(const Tensor& probs, const Tensor& tokens_per_expert, void fused_moe_aux_loss_forward(const Tensor& probs, const Tensor& tokens_per_expert,
......
...@@ -39,11 +39,19 @@ __device__ inline T warp_reduce_on_shmem(T *data_ptr, int data_size, T (*reduce_ ...@@ -39,11 +39,19 @@ __device__ inline T warp_reduce_on_shmem(T *data_ptr, int data_size, T (*reduce_
} }
// Warp shuffle between threads // Warp shuffle between threads
#ifdef __HIP_PLATFORM_AMD__
val = reduce_func(val, __shfl_xor_sync((unsigned long long)0xffffffff, val, 16, kThreadsPerWarp));
val = reduce_func(val, __shfl_xor_sync((unsigned long long)0xffffffff, val, 8, kThreadsPerWarp));
val = reduce_func(val, __shfl_xor_sync((unsigned long long)0xffffffff, val, 4, kThreadsPerWarp));
val = reduce_func(val, __shfl_xor_sync((unsigned long long)0xffffffff, val, 2, kThreadsPerWarp));
val = reduce_func(val, __shfl_xor_sync((unsigned long long)0xffffffff, val, 1, kThreadsPerWarp));
#else
val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 16)); val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 16));
val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 8)); val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 8));
val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 4)); val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 4));
val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 2)); val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 2));
val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 1)); val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 1));
#endif
__syncwarp(); __syncwarp();
return T(val); return T(val);
} }
...@@ -71,11 +79,19 @@ __device__ inline T masked_warp_reduce_on_shmem(T *data_ptr, bool *mask, int dat ...@@ -71,11 +79,19 @@ __device__ inline T masked_warp_reduce_on_shmem(T *data_ptr, bool *mask, int dat
} }
// Warp shuffle between threads // Warp shuffle between threads
#ifdef __HIP_PLATFORM_AMD__
val = reduce_func(val, __shfl_xor_sync((unsigned long long)0xffffffff, val, 16, kThreadsPerWarp));
val = reduce_func(val, __shfl_xor_sync((unsigned long long)0xffffffff, val, 8, kThreadsPerWarp));
val = reduce_func(val, __shfl_xor_sync((unsigned long long)0xffffffff, val, 4, kThreadsPerWarp));
val = reduce_func(val, __shfl_xor_sync((unsigned long long)0xffffffff, val, 2, kThreadsPerWarp));
val = reduce_func(val, __shfl_xor_sync((unsigned long long)0xffffffff, val, 1, kThreadsPerWarp));
#else
val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 16)); val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 16));
val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 8)); val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 8));
val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 4)); val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 4));
val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 2)); val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 2));
val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 1)); val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 1));
#endif
__syncwarp(); __syncwarp();
return T(val); return T(val);
} }
...@@ -165,8 +181,13 @@ __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, i ...@@ -165,8 +181,13 @@ __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, i
} }
// Warp shuffle between threads // Warp shuffle between threads
for (int s = 16; s > 0; s /= 2) { for (int s = 16; s > 0; s /= 2) {
#ifdef __HIP_PLATFORM_AMD__
volatile auto shuffled_val = __shfl_xor_sync((unsigned long long)0xffffffff, val, s, kThreadsPerWarp);
volatile auto shuffled_index = __shfl_xor_sync((unsigned long long)0xffffffff, index, s, kThreadsPerWarp);
#else
volatile auto shuffled_val = __shfl_xor_sync(0xffffffff, val, s); volatile auto shuffled_val = __shfl_xor_sync(0xffffffff, val, s);
volatile auto shuffled_index = __shfl_xor_sync(0xffffffff, index, s); volatile auto shuffled_index = __shfl_xor_sync(0xffffffff, index, s);
#endif
if (shuffled_val > val) { if (shuffled_val > val) {
val = shuffled_val; val = shuffled_val;
index = shuffled_index; index = shuffled_index;
......
...@@ -382,7 +382,7 @@ std::vector<size_t> Float8BlockQuantizer::get_scale_shape(const std::vector<size ...@@ -382,7 +382,7 @@ std::vector<size_t> Float8BlockQuantizer::get_scale_shape(const std::vector<size
size_t k_dim = shape.size() == 0 ? 1u : shape.back(); size_t k_dim = shape.size() == 0 ? 1u : shape.back();
size_t m_dim = numel / k_dim; size_t m_dim = numel / k_dim;
constexpr size_t kBlockLen = 128; size_t kBlockLen = static_cast<size_t>(blockwise_fp8_block_len());
Float8BlockScaleTensorFormat data_format = Float8BlockScaleTensorFormat data_format =
(all_gather_usage ? Float8BlockScaleTensorFormat::COMPACT (all_gather_usage ? Float8BlockScaleTensorFormat::COMPACT
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment