Unverified Commit 4a51d2da authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Fix grouped_gemm_splitk kernels on MI300. (#694)



* replace amd_buffer_atomic_add with hip_atomic_add

* fix grouped_gemm_splitk kernels on mi300

* fix syntax

* revert experimental atomic_add changes

---------
Co-authored-by: default avatarJing Zhang <jizhan@amd.com>
parent 86e0190e
...@@ -147,7 +147,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -147,7 +147,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
#else #else
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data()); a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
c_tensors_device[i]->SetZero(); c_tensors_device[i]->SetZero();
#endif #endif
p_a.push_back(a_tensors_device[i]->GetDeviceBuffer()); p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
......
...@@ -34,7 +34,8 @@ __global__ void ...@@ -34,7 +34,8 @@ __global__ void
kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const index_t group_count) const index_t group_count)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
__shared__ uint8_t p_shared[shared_size]; __shared__ uint8_t p_shared[shared_size];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment