Commit eac75188 authored by yuguo's avatar yuguo
Browse files

[DCU] fix merge

parent 44740c6c
......@@ -168,6 +168,7 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs,
int num_cols, int topk, float coeff,
DataType* aux_loss, float* Const_buf,
cudaStream_t stream) {
#ifndef __HIP_PLATFORM_AMD__
if (cuda::sm_arch(cuda::current_device()) >= 90) {
cudaLaunchConfig_t config = {0};
int cluster_size = 8;
......@@ -193,11 +194,14 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs,
tokens_per_expert, total_num_tokens, num_experts, num_rows, num_cols, topk,
coeff, aux_loss, Const_buf);
} else {
#endif
size_t smem_size = sizeof(CompType) * num_cols;
fused_moe_aux_loss_forward_kernel<DataType, IndexType>
<<<1, 1024, smem_size, stream>>>(probs, tokens_per_expert, total_num_tokens, num_experts,
num_rows, num_cols, topk, coeff, aux_loss, Const_buf);
#ifndef __HIP_PLATFORM_AMD__
}
#endif
}
void fused_moe_aux_loss_forward(const Tensor& probs, const Tensor& tokens_per_expert,
......
......@@ -39,11 +39,19 @@ __device__ inline T warp_reduce_on_shmem(T *data_ptr, int data_size, T (*reduce_
}
// Warp shuffle between threads
#ifdef __HIP_PLATFORM_AMD__
val = reduce_func(val, __shfl_xor_sync((unsigned long long)0xffffffff, val, 16, kThreadsPerWarp));
val = reduce_func(val, __shfl_xor_sync((unsigned long long)0xffffffff, val, 8, kThreadsPerWarp));
val = reduce_func(val, __shfl_xor_sync((unsigned long long)0xffffffff, val, 4, kThreadsPerWarp));
val = reduce_func(val, __shfl_xor_sync((unsigned long long)0xffffffff, val, 2, kThreadsPerWarp));
val = reduce_func(val, __shfl_xor_sync((unsigned long long)0xffffffff, val, 1, kThreadsPerWarp));
#else
val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 16));
val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 8));
val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 4));
val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 2));
val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 1));
#endif
__syncwarp();
return T(val);
}
......@@ -71,11 +79,19 @@ __device__ inline T masked_warp_reduce_on_shmem(T *data_ptr, bool *mask, int dat
}
// Warp shuffle between threads
#ifdef __HIP_PLATFORM_AMD__
val = reduce_func(val, __shfl_xor_sync((unsigned long long)0xffffffff, val, 16, kThreadsPerWarp));
val = reduce_func(val, __shfl_xor_sync((unsigned long long)0xffffffff, val, 8, kThreadsPerWarp));
val = reduce_func(val, __shfl_xor_sync((unsigned long long)0xffffffff, val, 4, kThreadsPerWarp));
val = reduce_func(val, __shfl_xor_sync((unsigned long long)0xffffffff, val, 2, kThreadsPerWarp));
val = reduce_func(val, __shfl_xor_sync((unsigned long long)0xffffffff, val, 1, kThreadsPerWarp));
#else
val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 16));
val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 8));
val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 4));
val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 2));
val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 1));
#endif
__syncwarp();
return T(val);
}
......@@ -165,8 +181,13 @@ __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, i
}
// Warp shuffle between threads
for (int s = 16; s > 0; s /= 2) {
#ifdef __HIP_PLATFORM_AMD__
volatile auto shuffled_val = __shfl_xor_sync((unsigned long long)0xffffffff, val, s, kThreadsPerWarp);
volatile auto shuffled_index = __shfl_xor_sync((unsigned long long)0xffffffff, index, s, kThreadsPerWarp);
#else
volatile auto shuffled_val = __shfl_xor_sync(0xffffffff, val, s);
volatile auto shuffled_index = __shfl_xor_sync(0xffffffff, index, s);
#endif
if (shuffled_val > val) {
val = shuffled_val;
index = shuffled_index;
......
......@@ -382,7 +382,7 @@ std::vector<size_t> Float8BlockQuantizer::get_scale_shape(const std::vector<size
size_t k_dim = shape.size() == 0 ? 1u : shape.back();
size_t m_dim = numel / k_dim;
constexpr size_t kBlockLen = 128;
size_t kBlockLen = static_cast<size_t>(blockwise_fp8_block_len());
Float8BlockScaleTensorFormat data_format =
(all_gather_usage ? Float8BlockScaleTensorFormat::COMPACT
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment