Commit 0e877b84 authored by Chao Liu
Browse files

adding thread group

parent 3f4af14c
...@@ -26,17 +26,14 @@ ...@@ -26,17 +26,14 @@
#endif #endif
#endif #endif
// buffer resourse, wave size // buffer resourse
#ifndef __HIP_DEVICE_COMPILE__ // for host code #ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_BUFFER_RESOURCE_3RD_DWORD -1 #define CK_BUFFER_RESOURCE_3RD_DWORD -1
#define CK_GPU_WAVE_SIZE -1
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \ #elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) // for GPU code defined(__gfx90a__) // for GPU code
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#define CK_GPU_WAVE_SIZE 64
#elif defined(__gfx1030__) // for GPU code #elif defined(__gfx1030__) // for GPU code
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#define CK_GPU_WAVE_SIZE 32
#endif #endif
// FMA instruction // FMA instruction
......
#ifndef CK_BLOCKWISE_GEMM_XDLOPS_HPP #pragma once
#define CK_BLOCKWISE_GEMM_XDLOPS_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "threadwise_tensor_slice_transfer.hpp" #include "threadwise_tensor_slice_transfer.hpp"
#include "xdlops_gemm.hpp" #include "xdlops_gemm.hpp"
...@@ -8,7 +6,7 @@ ...@@ -8,7 +6,7 @@
namespace ck { namespace ck {
template <index_t BlockSize, template <typename ThreadGroup,
typename FloatAB, typename FloatAB,
typename FloatAcc, typename FloatAcc,
typename AK0MK1BlockDesc, typename AK0MK1BlockDesc,
...@@ -25,7 +23,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -25,7 +23,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
static constexpr index_t WaveSize = 64; static constexpr index_t WaveSize = get_warp_size();
static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
...@@ -53,7 +51,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -53,7 +51,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
__device__ static auto GetWaveIdx() __device__ static auto GetWaveIdx()
{ {
const index_t thread_id = get_thread_local_1d_id(); const index_t thread_id = ThreadGroup::GetThreadId();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
...@@ -120,8 +118,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -120,8 +118,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
BK0NK1BlockDesc::IsKnownAtCompileTime(), BK0NK1BlockDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time"); "wrong! Desc should be known at compile-time");
static_assert(BlockSize == MWaves * NWaves * WaveSize, static_assert(ThreadGroup::GetNumOfThread() == MWaves * NWaves * WaveSize,
"BlockSize != MWaves * NWaves * WaveSize\n"); "ThreadGroup::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0, static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
"wrong!"); "wrong!");
...@@ -337,4 +335,3 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -337,4 +335,3 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
}; };
} // namespace ck } // namespace ck
#endif
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V4R1_HPP #pragma once
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V4R1_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
...@@ -169,4 +167,3 @@ struct BlockwiseTensorSliceTransfer_v4r1 ...@@ -169,4 +167,3 @@ struct BlockwiseTensorSliceTransfer_v4r1
}; };
} // namespace ck } // namespace ck
#endif
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include "transpose_vectors.hpp" #include "transpose_vectors.hpp"
#include "inner_product.hpp" #include "inner_product.hpp"
#include "element_wise_operation.hpp" #include "element_wise_operation.hpp"
#include "thread_group.hpp"
#include "debug.hpp" #include "debug.hpp"
#include "amd_buffer_addressing.hpp" #include "amd_buffer_addressing.hpp"
......
...@@ -3,11 +3,14 @@ ...@@ -3,11 +3,14 @@
namespace ck { namespace ck {
__device__ constexpr index_t get_wave_size() { return CK_GPU_WAVE_SIZE; } __host__ __device__ constexpr index_t get_warp_size()
{ // warpSize is defined by HIP
return warpSize;
}
__device__ index_t get_thread_local_1d_id() { return threadIdx.x; } __device__ index_t get_thread_local_1d_id() { return threadIdx.x; }
__device__ index_t get_wave_local_1d_id() { return threadIdx.x / get_wave_size(); } __device__ index_t get_warp_local_1d_id() { return threadIdx.x / get_warp_size(); }
__device__ index_t get_block_1d_id() { return blockIdx.x; } __device__ index_t get_block_1d_id() { return blockIdx.x; }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment