"example/01_gemm/gemm_wmma_int8.cpp" did not exist on "fa2d894be1b3c0213da06d58af0df2de2c5308ad"
Commit ac580f77 authored by Alan Turner

Merge remote-tracking branch 'origin/develop' into migx-jit-lib

parents 707d6261 027e46ee
@@ -136,7 +136,7 @@ __global__ void
             const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__))
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     // offset base pointer for each work-group
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
@@ -685,7 +685,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
                 return false;
             }
         }
-        else if(get_device_name() == "gfx90a" || get_device_name() == "gfx940")
+        else if(get_device_name() == "gfx90a" || get_device_name() == "gfx940" ||
+                get_device_name() == "gfx941" || get_device_name() == "gfx942")
         {
             if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
                            is_same_v<AccDataType, int32_t> || is_same_v<AccDataType, double>))
...
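Note: the same architecture-guard edit recurs throughout this merge: every occurrence of `defined(__gfx940__)` gains `defined(__gfx941__) || defined(__gfx942__)` so the kernels are also emitted for the gfx941/gfx942 targets. A minimal sketch of what the guard does, assuming a hypothetical `CK_XDL_KERNEL_ENABLED` helper macro that this commit does not actually introduce (it edits each `#if` in place):

```cpp
// Hypothetical consolidation sketch of the repeated guard. The host pass
// (__HIP_DEVICE_COMPILE__ undefined) must keep the kernel body visible so
// host-side template instantiation still succeeds; the device pass emits the
// body only for MFMA-capable targets.
#if !defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#define CK_XDL_KERNEL_ENABLED 1
#else
#define CK_XDL_KERNEL_ENABLED 0
#endif
```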
@@ -41,7 +41,7 @@ __global__ void
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
     defined(__gfx90a__) || defined(__gfx1030__) || defined(__gfx1100__) || defined(__gfx1101__) || \
-    defined(__gfx1102__) || defined(__gfx940__))
+    defined(__gfx1102__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     const index_t block_id = get_block_1d_id();
...
@@ -39,7 +39,7 @@ __global__ void
             const CDEElementwiseOperation c_element_op)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__))
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     const index_t block_id = get_block_1d_id();
...
@@ -35,7 +35,7 @@ __global__ void
             const index_t group_count)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__))
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
     __shared__ uint8_t p_shared[shared_size];
...
@@ -80,7 +80,8 @@ template <typename FloatAB,
           LoopScheduler LoopSched,
           bool PadN,
           bool MaskOutUpperTriangle,
+          int D0sTransferSrcScalarPerVector = 4,
           PipelineVersion PipelineVer = PipelineVersion::v1>
 struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
 {
     static_assert(LoopSched == LoopScheduler::Default,
@@ -621,13 +622,13 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
         constexpr auto d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
             make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
                                                            I1, // NBlockID
-                                                           I1, // MRepeat
-                                                           I1, // NRepeat
-                                                           I1, // MWaveId
-                                                           I1, // NWaveId
-                                                           I1, // MPerXdl
-                                                           I1, // NGroupNum
-                                                           I1, // NInputNum
+                                                           m0, // MRepeat
+                                                           n0, // NRepeat
+                                                           m1, // MWaveId
+                                                           n1, // NWaveId
+                                                           m2, // MPerXdl
+                                                           n2, // NGroupNum
+                                                           n3, // NInputNum
                                                            n4)); // registerNum

         auto d0s_thread_buf = generate_tuple(
@@ -644,9 +645,6 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
         const auto wave_id     = GetGemm0WaveIdx();
         const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63

-        constexpr auto acc0_thread_desc = make_naive_tensor_descriptor_packed(
-            make_tuple(Number<MXdlPerWave>{}, Number<NXdlPerWave>{}, n2, n4));
-
         auto d0s_threadwise_copy = generate_tuple(
             [&](auto i) {
                 using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
@@ -655,10 +653,19 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
                     D0DataType,
                     decltype(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i]),
                     decltype(d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
-                    Sequence<I1, I1, I1, I1, I1, I1, I1, I1, I1, n4>,
+                    Sequence<I1, // MBlockId
+                             I1, // NBlockID
+                             m0, // MRepeat
+                             n0, // NRepeat
+                             m1, // MWaveId
+                             n1, // NWaveId
+                             m2, // MPerXdl
+                             n2, // NGroupNum
+                             n3, // NInputNum
+                             n4>,
                     Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
                     9,
-                    n4,
+                    D0sTransferSrcScalarPerVector,
                     1,
                     false>(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
                            make_multi_index(block_work_idx[I0], // MBlockId
@@ -884,62 +891,35 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
             // multiple d
             if constexpr(NumD0Tensor)
             {
-                static_for<0, MXdlPerWave, 1>{}([&](auto mr) {
-                    static_for<0, NXdlPerWave, 1>{}([&](auto nr) {
-                        static_for<0, n2, 1>{}([&](auto groupid) {
-                            static_for<0, NumD0Tensor, 1>{}([&](auto i) {
-                                d0s_threadwise_copy(i).Run(
-                                    d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
-                                    d0s_grid_buf[i],
-                                    d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                                    make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
-                                    d0s_thread_buf(i));
-                            });
-                            static_for<0, n4, 1>{}([&](auto i) {
-                                constexpr index_t c_offset = acc0_thread_desc.CalculateOffset(
-                                    make_tuple(mr, nr, groupid, i));
-
-                                // get reference to src data
-                                const auto src_data_refs = generate_tie(
-                                    // return type should be lvalue
-                                    [&](auto iSrc) -> const auto& {
-                                        return d0s_thread_buf[iSrc][i];
-                                    },
-                                    Number<NumD0Tensor>{});
-
-                                // get reference to dst data
-                                auto dst_data_refs = generate_tie(
-                                    // return type should be lvalue
-                                    [&](auto) -> auto& {
-                                        return acc_thread_buf(Number<c_offset>{});
-                                    },
-                                    Number<2>{});
-
-                                unpack2(c0de_element_op, dst_data_refs, src_data_refs);
-                            });
-                            static_for<0, NumD0Tensor, 1>{}([&](auto i) {
-                                d0s_threadwise_copy(i).MoveSrcSliceWindow(
-                                    d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
-                                    make_multi_index(0, 0, 0, 0, 0, 0, 0, 1, 0, 0));
-                            });
-                        });
-                        static_for<0, NumD0Tensor, 1>{}([&](auto i) {
-                            d0s_threadwise_copy(i).MoveSrcSliceWindow(
-                                d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
-                                make_multi_index(0, 0, 0, 1, 0, 0, 0, -n2.value, 0, 0));
-                        });
-                    });
-                    static_for<0, NumD0Tensor, 1>{}([&](auto i) {
-                        d0s_threadwise_copy(i).MoveSrcSliceWindow(
-                            d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
-                            make_multi_index(0, 0, 1, -NXdlPerWave, 0, 0, 0, 0, 0, 0));
-                    });
-                });
+                static_assert(NXdlPerWave == n0);
+                static_assert(MXdlPerWave == m0);
+
+                static_for<0, NumD0Tensor, 1>{}([&](auto i) {
+                    d0s_threadwise_copy(i).Run(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
+                                               d0s_grid_buf[i],
+                                               d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                                               make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
+                                               d0s_thread_buf(i));
+                });
+                static_for<0, m0 * n0 * n2 * n4, 1>{}([&](auto i) {
+                    // get reference to src data
+                    const auto src_data_refs = generate_tie(
+                        // return type should be lvalue
+                        [&](auto iSrc) -> const auto& { return d0s_thread_buf[iSrc][i]; },
+                        Number<NumD0Tensor>{});
+
+                    // get reference to dst data
+                    auto dst_data_refs = generate_tie(
+                        // return type should be lvalue
+                        [&](auto) -> auto& { return acc_thread_buf(i); },
+                        Number<2>{});
+
+                    unpack2(c0de_element_op, dst_data_refs, src_data_refs);
+                });
                 static_for<0, NumD0Tensor, 1>{}([&](auto i) {
                     d0s_threadwise_copy(i).MoveSrcSliceWindow(
                         d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
-                        make_multi_index(0, 1, -MXdlPerWave, 0, 0, 0, 0, 0, 0, 0));
+                        make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
                 });
             }
             else
...
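Note: this hunk replaces the four-level `static_for` nest (MRepeat x NRepeat x NGroup x register, with slice-window rewinds between levels) with one bulk `Run` per D0 tensor followed by a single flat loop of length `m0 * n0 * n2 * n4`. That is valid because the thread descriptor now carries the full per-thread tile extents, so the packed offsets of `d0s_thread_buf` and `acc_thread_buf` coincide; the new `D0sTransferSrcScalarPerVector` parameter also decouples the global-load vector width from `n4`, which was previously hard-coded. A small host-side check of the offset equivalence, as a sketch (the `M0`/`N0`/`N2`/`N4` values are illustrative):

```cpp
#include <cassert>
#include <cstddef>

// The removed acc0_thread_desc computed offset = ((mr*N0 + nr)*N2 + g)*N4 + r for a
// packed (MRepeat, NRepeat, NGroup, Register) layout; a flat index over the same
// total length visits identical offsets in identical order, so one loop suffices.
int main()
{
    constexpr std::size_t M0 = 2, N0 = 4, N2 = 2, N4 = 4;

    std::size_t flat = 0;
    for(std::size_t mr = 0; mr < M0; ++mr)
        for(std::size_t nr = 0; nr < N0; ++nr)
            for(std::size_t g = 0; g < N2; ++g)
                for(std::size_t r = 0; r < N4; ++r)
                    assert(((mr * N0 + nr) * N2 + g) * N4 + r == flat++);

    return 0;
}
```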
@@ -67,7 +67,7 @@ __global__ void
             const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__))
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
...
@@ -55,7 +55,7 @@ __global__ void
             const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940))
+    defined(__gfx940) || defined(__gfx941__) || defined(__gfx942__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
...
@@ -25,7 +25,7 @@ __global__ void
         kernel_gemm_xdl_cshuffle_v1(typename GridwiseGemm::Argument karg)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__))
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     GridwiseGemm::template Run<HasMainKBlockLoop>(
@@ -46,7 +46,7 @@ __global__ void
             typename GridwiseGemm::Problem problem)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__))
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, p_b_grid, p_c_grid, p_shared, problem);
...
@@ -58,7 +58,7 @@ __global__ void
             const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__))
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     // TODO ANT: separate into MMA + Epilogue
...
@@ -166,7 +166,7 @@ __global__ void
             const CBlockClusterAdaptor c_block_cluster_adaptor)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__))
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
...
@@ -45,7 +45,7 @@ __global__ void
             const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__))
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,
...
@@ -36,7 +36,7 @@ __global__ void
             const CGridDesc_M_N c_grid_desc_m_n)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__))
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
@@ -64,7 +64,7 @@ __global__ void
         kernel_gemm_xdlops_v2r3(const typename GridwiseGemm::Argument karg)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__))
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     const auto a_grid_desc_k0_m_k1 =
...
@@ -43,7 +43,7 @@ __global__ void
             const CBlockClusterAdaptor c_block_cluster_adaptor)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__))
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     constexpr index_t shared_block_size =
         GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
...
@@ -31,7 +31,7 @@ __global__ void
             const Block2CTileMap& b2c_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__))
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
     __shared__ uint8_t p_shared[shared_size];
...
@@ -47,7 +47,7 @@ __global__ void
             const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__))
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     GridwiseGemm::template Run<HasMainK0BlockLoop>(
...
@@ -50,7 +50,7 @@ __global__ void
             const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__))
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     GridwiseGemm::template Run<HasMainKBlockLoop>(
...
@@ -54,7 +54,7 @@ __global__ void
             const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__))
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     GridwiseGemm::template Run<HasMainKBlockLoop>(
...
@@ -286,7 +286,22 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
                                        int soffset, // dst_wave_addr_offset
                                        int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64");

-template <typename T, index_t N>
+// memory coherency bit for buffer store/load instruction
+// check ISA manual for each GFX target
+// e.g. for
+// https://www.amd.com/system/files/TechDocs/instinct-mi200-cdna2-instruction-set-architecture.pdf,
+// page 67~68
+enum struct AmdBufferCoherenceEnum
+{
+    DefaultCoherence = 0, // default value
+    GLC              = 1,
+    SLC              = 2,
+    GLC_SLC          = 3,
+};
+
+template <typename T,
+          index_t N,
+          AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
 __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource,
                                                                  index_t src_thread_addr_offset,
                                                                  index_t src_wave_addr_offset)
@@ -305,28 +320,37 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
         // use fp32 load to mimic fp64 load
         if constexpr(N == 1)
         {
-            const float2_t tmp = llvm_amdgcn_raw_buffer_load_fp32x2(
-                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+            const float2_t tmp =
+                llvm_amdgcn_raw_buffer_load_fp32x2(src_wave_buffer_resource,
+                                                   src_thread_addr_offset,
+                                                   src_wave_addr_offset,
+                                                   static_cast<index_t>(coherence));

             return bit_cast<double>(tmp);
         }
         else if constexpr(N == 2)
         {
-            const float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(
-                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+            const float4_t tmp =
+                llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
+                                                   src_thread_addr_offset,
+                                                   src_wave_addr_offset,
+                                                   static_cast<index_t>(coherence));

             return bit_cast<double2_t>(tmp);
         }
         else if constexpr(N == 4)
         {
-            const float4_t f32_0 = llvm_amdgcn_raw_buffer_load_fp32x4(
-                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+            const float4_t f32_0 =
+                llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
+                                                   src_thread_addr_offset,
+                                                   src_wave_addr_offset,
+                                                   static_cast<index_t>(coherence));

             const float4_t f32_1 =
                 llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
                                                    src_thread_addr_offset,
                                                    src_wave_addr_offset + 4 * sizeof(float),
-                                                   0);
+                                                   static_cast<index_t>(coherence));

             vector_type<double, 4> tmp;

             tmp.AsType<double2_t>()(Number<0>{}) = bit_cast<double2_t>(f32_0);
@@ -339,31 +363,40 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
     {
         if constexpr(N == 1)
         {
-            return llvm_amdgcn_raw_buffer_load_fp32(
-                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+            return llvm_amdgcn_raw_buffer_load_fp32(src_wave_buffer_resource,
+                                                    src_thread_addr_offset,
+                                                    src_wave_addr_offset,
+                                                    static_cast<index_t>(coherence));
         }
         else if constexpr(N == 2)
         {
-            return llvm_amdgcn_raw_buffer_load_fp32x2(
-                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+            return llvm_amdgcn_raw_buffer_load_fp32x2(src_wave_buffer_resource,
+                                                      src_thread_addr_offset,
+                                                      src_wave_addr_offset,
+                                                      static_cast<index_t>(coherence));
         }
         else if constexpr(N == 4)
         {
-            return llvm_amdgcn_raw_buffer_load_fp32x4(
-                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+            return llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
+                                                      src_thread_addr_offset,
+                                                      src_wave_addr_offset,
+                                                      static_cast<index_t>(coherence));
         }
         else if constexpr(N == 8)
         {
             vector_type<float, 8> tmp;

-            tmp.AsType<float4_t>()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_fp32x4(
-                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+            tmp.AsType<float4_t>()(Number<0>{}) =
+                llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
+                                                   src_thread_addr_offset,
+                                                   src_wave_addr_offset,
+                                                   static_cast<index_t>(coherence));

             tmp.AsType<float4_t>()(Number<1>{}) =
                 llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
                                                    src_thread_addr_offset,
                                                    src_wave_addr_offset + 4 * sizeof(float),
-                                                   0);
+                                                   static_cast<index_t>(coherence));

             return tmp.AsType<float8_t>()(Number<0>{});
         }
@@ -372,24 +405,32 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
     {
         if constexpr(N == 1)
         {
-            return llvm_amdgcn_raw_buffer_load_fp16(
-                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+            return llvm_amdgcn_raw_buffer_load_fp16(src_wave_buffer_resource,
+                                                    src_thread_addr_offset,
+                                                    src_wave_addr_offset,
+                                                    static_cast<index_t>(coherence));
         }
         else if constexpr(N == 2)
         {
-            return llvm_amdgcn_raw_buffer_load_fp16x2(
-                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+            return llvm_amdgcn_raw_buffer_load_fp16x2(src_wave_buffer_resource,
+                                                      src_thread_addr_offset,
+                                                      src_wave_addr_offset,
+                                                      static_cast<index_t>(coherence));
         }
         else if constexpr(N == 4)
         {
-            return llvm_amdgcn_raw_buffer_load_fp16x4(
-                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+            return llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource,
+                                                      src_thread_addr_offset,
+                                                      src_wave_addr_offset,
+                                                      static_cast<index_t>(coherence));
         }
         else if constexpr(N == 8)
         {
             // use fp32 load to mimic fp16 load
-            float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(
-                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+            float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
+                                                              src_thread_addr_offset,
+                                                              src_wave_addr_offset,
+                                                              static_cast<index_t>(coherence));

             return bit_cast<half8_t>(tmp);
         }
@@ -398,23 +439,31 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
     {
         if constexpr(N == 1)
         {
-            return llvm_amdgcn_raw_buffer_load_i16(
-                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+            return llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource,
+                                                   src_thread_addr_offset,
+                                                   src_wave_addr_offset,
+                                                   static_cast<index_t>(coherence));
         }
         else if constexpr(N == 2)
         {
-            return llvm_amdgcn_raw_buffer_load_i16x2(
-                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+            return llvm_amdgcn_raw_buffer_load_i16x2(src_wave_buffer_resource,
+                                                     src_thread_addr_offset,
+                                                     src_wave_addr_offset,
+                                                     static_cast<index_t>(coherence));
         }
         else if constexpr(N == 4)
         {
-            return llvm_amdgcn_raw_buffer_load_i16x4(
-                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+            return llvm_amdgcn_raw_buffer_load_i16x4(src_wave_buffer_resource,
+                                                     src_thread_addr_offset,
+                                                     src_wave_addr_offset,
+                                                     static_cast<index_t>(coherence));
         }
         else if constexpr(N == 8)
         {
-            int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(
-                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+            int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
+                                                              src_thread_addr_offset,
+                                                              src_wave_addr_offset,
+                                                              static_cast<index_t>(coherence));

             return bit_cast<bhalf8_t>(tmp);
         }
@@ -423,31 +472,40 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
     {
         if constexpr(N == 1)
         {
-            return llvm_amdgcn_raw_buffer_load_i32(
-                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+            return llvm_amdgcn_raw_buffer_load_i32(src_wave_buffer_resource,
+                                                   src_thread_addr_offset,
+                                                   src_wave_addr_offset,
+                                                   static_cast<index_t>(coherence));
         }
         else if constexpr(N == 2)
         {
-            return llvm_amdgcn_raw_buffer_load_i32x2(
-                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+            return llvm_amdgcn_raw_buffer_load_i32x2(src_wave_buffer_resource,
+                                                     src_thread_addr_offset,
+                                                     src_wave_addr_offset,
+                                                     static_cast<index_t>(coherence));
         }
         else if constexpr(N == 4)
         {
-            return llvm_amdgcn_raw_buffer_load_i32x4(
-                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+            return llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
+                                                     src_thread_addr_offset,
+                                                     src_wave_addr_offset,
+                                                     static_cast<index_t>(coherence));
         }
         else if constexpr(N == 8)
         {
             vector_type<int32_t, 8> tmp;

-            tmp.AsType<int32x4_t>()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_i32x4(
-                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+            tmp.AsType<int32x4_t>()(Number<0>{}) =
+                llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
+                                                  src_thread_addr_offset,
+                                                  src_wave_addr_offset,
+                                                  static_cast<index_t>(coherence));

             tmp.AsType<int32x4_t>()(Number<1>{}) =
                 llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
                                                   src_thread_addr_offset,
                                                   src_wave_addr_offset + 4 * sizeof(int32_t),
-                                                  0);
+                                                  static_cast<index_t>(coherence));

             return tmp.AsType<int32x8_t>()(Number<0>{});
         }
     }
@@ -455,17 +513,23 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
     {
         if constexpr(N == 1)
         {
-            return llvm_amdgcn_raw_buffer_load_i8(
-                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+            return llvm_amdgcn_raw_buffer_load_i8(src_wave_buffer_resource,
+                                                  src_thread_addr_offset,
+                                                  src_wave_addr_offset,
+                                                  static_cast<index_t>(coherence));
         }
         else if constexpr(N == 2)
         {
 #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
-            return llvm_amdgcn_raw_buffer_load_i8x2(
-                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+            return llvm_amdgcn_raw_buffer_load_i8x2(src_wave_buffer_resource,
+                                                    src_thread_addr_offset,
+                                                    src_wave_addr_offset,
+                                                    static_cast<index_t>(coherence));
 #else
-            int16_t tmp = llvm_amdgcn_raw_buffer_load_i16(
-                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+            int16_t tmp = llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource,
+                                                          src_thread_addr_offset,
+                                                          src_wave_addr_offset,
+                                                          static_cast<index_t>(coherence));

             return bit_cast<int8x2_t>(tmp);
 #endif
@@ -473,11 +537,15 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
         else if constexpr(N == 4)
         {
 #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
-            return llvm_amdgcn_raw_buffer_load_i8x4(
-                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+            return llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
+                                                    src_thread_addr_offset,
+                                                    src_wave_addr_offset,
+                                                    static_cast<index_t>(coherence));
 #else
-            int32_t tmp = llvm_amdgcn_raw_buffer_load_i32(
-                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+            int32_t tmp = llvm_amdgcn_raw_buffer_load_i32(src_wave_buffer_resource,
+                                                          src_thread_addr_offset,
+                                                          src_wave_addr_offset,
+                                                          static_cast<index_t>(coherence));

             return bit_cast<int8x4_t>(tmp);
 #endif
@@ -487,19 +555,24 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
 #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
             vector_type<int8_t, 8> tmp;

-            tmp.AsType<int8x4_t>()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_i8x4(
-                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+            tmp.AsType<int8x4_t>()(Number<0>{}) =
+                llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
+                                                 src_thread_addr_offset,
+                                                 src_wave_addr_offset,
+                                                 static_cast<index_t>(coherence));

             tmp.AsType<int8x4_t>()(Number<1>{}) =
                 llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
                                                  src_thread_addr_offset,
                                                  src_wave_addr_offset + 4 * sizeof(int8_t),
-                                                 0);
+                                                 static_cast<index_t>(coherence));

             return tmp.AsType<int8x8_t>()(Number<0>{});
 #else
-            int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2(
-                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+            int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2(src_wave_buffer_resource,
+                                                              src_thread_addr_offset,
+                                                              src_wave_addr_offset,
+                                                              static_cast<index_t>(coherence));

             return bit_cast<int8x8_t>(tmp);
 #endif
@@ -509,31 +582,36 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
 #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
             vector_type<int8_t, 16> tmp;

-            tmp.AsType<int8x4_t>()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_i8x4(
-                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+            tmp.AsType<int8x4_t>()(Number<0>{}) =
+                llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
+                                                 src_thread_addr_offset,
+                                                 src_wave_addr_offset,
+                                                 static_cast<index_t>(coherence));

             tmp.AsType<int8x4_t>()(Number<1>{}) =
                 llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
                                                  src_thread_addr_offset,
                                                  src_wave_addr_offset + 4 * sizeof(int8_t),
-                                                 0);
+                                                 static_cast<index_t>(coherence));

             tmp.AsType<int8x4_t>()(Number<2>{}) =
                 llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
                                                  src_thread_addr_offset,
                                                  src_wave_addr_offset + 8 * sizeof(int8_t),
-                                                 0);
+                                                 static_cast<index_t>(coherence));

             tmp.AsType<int8x4_t>()(Number<3>{}) =
                 llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
                                                  src_thread_addr_offset,
                                                  src_wave_addr_offset + 12 * sizeof(int8_t),
-                                                 0);
+                                                 static_cast<index_t>(coherence));

             return tmp.AsType<int8x16_t>()(Number<0>{});
 #else
-            int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(
-                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+            int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
+                                                              src_thread_addr_offset,
+                                                              src_wave_addr_offset,
+                                                              static_cast<index_t>(coherence));

             return bit_cast<int8x16_t>(tmp);
 #endif
@@ -541,7 +619,9 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
     }
 }

-template <typename T, index_t N>
+template <typename T,
+          index_t N,
+          AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
 __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src_thread_data,
                                       int32x4_t dst_wave_buffer_resource,
                                       index_t dst_thread_addr_offset,
@@ -565,7 +645,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                              dst_wave_buffer_resource,
                                              dst_thread_addr_offset,
                                              dst_wave_addr_offset,
-                                             0);
+                                             static_cast<index_t>(coherence));
         }
         else if constexpr(N == 2)
         {
@@ -573,7 +653,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                              dst_wave_buffer_resource,
                                              dst_thread_addr_offset,
                                              dst_wave_addr_offset,
-                                             0);
+                                             static_cast<index_t>(coherence));
         }
     }
     else if constexpr(is_same<T, float>::value)
@@ -584,7 +664,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                              dst_wave_buffer_resource,
                                              dst_thread_addr_offset,
                                              dst_wave_addr_offset,
-                                             0);
+                                             static_cast<index_t>(coherence));
         }
         else if constexpr(N == 2)
         {
@@ -592,7 +672,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                              dst_wave_buffer_resource,
                                              dst_thread_addr_offset,
                                              dst_wave_addr_offset,
-                                             0);
+                                             static_cast<index_t>(coherence));
         }
         else if constexpr(N == 4)
         {
@@ -600,7 +680,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                              dst_wave_buffer_resource,
                                              dst_thread_addr_offset,
                                              dst_wave_addr_offset,
-                                             0);
+                                             static_cast<index_t>(coherence));
         }
     }
     else if constexpr(is_same<T, half_t>::value)
@@ -611,7 +691,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                              dst_wave_buffer_resource,
                                              dst_thread_addr_offset,
                                              dst_wave_addr_offset,
-                                             0);
+                                             static_cast<index_t>(coherence));
         }
         else if constexpr(N == 2)
         {
@@ -619,7 +699,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                              dst_wave_buffer_resource,
                                              dst_thread_addr_offset,
                                              dst_wave_addr_offset,
-                                             0);
+                                             static_cast<index_t>(coherence));
         }
         else if constexpr(N == 4)
         {
@@ -627,7 +707,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                              dst_wave_buffer_resource,
                                              dst_thread_addr_offset,
                                              dst_wave_addr_offset,
-                                             0);
+                                             static_cast<index_t>(coherence));
         }
         else if constexpr(N == 8)
         {
@@ -638,19 +718,19 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                                 dst_wave_buffer_resource,
                                                 dst_thread_addr_offset,
                                                 dst_wave_addr_offset,
-                                                0);
+                                                static_cast<index_t>(coherence));

             llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<1>{}],
                                                 dst_wave_buffer_resource,
                                                 dst_thread_addr_offset,
                                                 dst_wave_addr_offset + 4 * sizeof(half_t),
-                                                0);
+                                                static_cast<index_t>(coherence));
 #else
             llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast<float4_t>(src_thread_data),
                                                 dst_wave_buffer_resource,
                                                 dst_thread_addr_offset,
                                                 dst_wave_addr_offset,
-                                                0);
+                                                static_cast<index_t>(coherence));
 #endif
         }
     }
@@ -662,7 +742,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                              dst_wave_buffer_resource,
                                              dst_thread_addr_offset,
                                              dst_wave_addr_offset,
-                                             0);
+                                             static_cast<index_t>(coherence));
         }
         else if constexpr(N == 2)
         {
@@ -670,7 +750,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                              dst_wave_buffer_resource,
                                              dst_thread_addr_offset,
                                              dst_wave_addr_offset,
-                                             0);
+                                             static_cast<index_t>(coherence));
         }
         else if constexpr(N == 4)
         {
@@ -678,7 +758,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                              dst_wave_buffer_resource,
                                              dst_thread_addr_offset,
                                              dst_wave_addr_offset,
-                                             0);
+                                             static_cast<index_t>(coherence));
         }
         else if constexpr(N == 8)
         {
@@ -688,13 +768,13 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                                dst_wave_buffer_resource,
                                                dst_thread_addr_offset,
                                                dst_wave_addr_offset,
-                                               0);
+                                               static_cast<index_t>(coherence));

             llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType<bhalf4_t>()[Number<1>{}],
                                                dst_wave_buffer_resource,
                                                dst_thread_addr_offset,
                                                dst_wave_addr_offset + 4 * sizeof(bhalf_t),
-                                               0);
+                                               static_cast<index_t>(coherence));
         }
     }
     else if constexpr(is_same<T, int32_t>::value)
@@ -705,7 +785,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                              dst_wave_buffer_resource,
                                              dst_thread_addr_offset,
                                              dst_wave_addr_offset,
-                                             0);
+                                             static_cast<index_t>(coherence));
         }
         else if constexpr(N == 2)
         {
@@ -713,7 +793,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                              dst_wave_buffer_resource,
                                              dst_thread_addr_offset,
                                              dst_wave_addr_offset,
-                                             0);
+                                             static_cast<index_t>(coherence));
         }
         else if constexpr(N == 4)
         {
@@ -721,7 +801,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                              dst_wave_buffer_resource,
                                              dst_thread_addr_offset,
                                              dst_wave_addr_offset,
-                                             0);
+                                             static_cast<index_t>(coherence));
         }
     }
     else if constexpr(is_same<T, int8_t>::value)
@@ -732,7 +812,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                              dst_wave_buffer_resource,
                                              dst_thread_addr_offset,
                                              dst_wave_addr_offset,
-                                             0);
+                                             static_cast<index_t>(coherence));
         }
         else if constexpr(N == 2)
         {
@@ -741,13 +821,13 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                              dst_wave_buffer_resource,
                                              dst_thread_addr_offset,
                                              dst_wave_addr_offset,
-                                             0);
+                                             static_cast<index_t>(coherence));
 #else
             llvm_amdgcn_raw_buffer_store_i16(bit_cast<int16_t>(src_thread_data),
                                              dst_wave_buffer_resource,
                                              dst_thread_addr_offset,
                                              dst_wave_addr_offset,
-                                             0);
+                                             static_cast<index_t>(coherence));
 #endif
         }
         else if constexpr(N == 4)
@@ -757,13 +837,13 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                              dst_wave_buffer_resource,
                                              dst_thread_addr_offset,
                                              dst_wave_addr_offset,
-                                             0);
+                                             static_cast<index_t>(coherence));
 #else
             llvm_amdgcn_raw_buffer_store_i32(bit_cast<int32_t>(src_thread_data),
                                              dst_wave_buffer_resource,
                                              dst_thread_addr_offset,
                                              dst_wave_addr_offset,
-                                             0);
+                                             static_cast<index_t>(coherence));
 #endif
         }
         else if constexpr(N == 8)
@@ -772,7 +852,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                              dst_wave_buffer_resource,
                                              dst_thread_addr_offset,
                                              dst_wave_addr_offset,
-                                             0);
+                                             static_cast<index_t>(coherence));
         }
         else if constexpr(N == 16)
         {
@@ -780,7 +860,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                              dst_wave_buffer_resource,
                                              dst_thread_addr_offset,
                                              dst_wave_addr_offset,
-                                             0);
+                                             static_cast<index_t>(coherence));
         }
     }
 }
@@ -1012,7 +1092,9 @@ __device__ void amd_buffer_atomic_max_impl(const typename vector_type<T, N>::typ
 // 1) p_src_wave must point to global memory space
 // 2) p_src_wave must be a wavewise pointer.
 // It is user's responsibility to make sure that is true.
-template <typename T, index_t N>
+template <typename T,
+          index_t N,
+          AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
 __device__ typename vector_type_maker<T, N>::type::type
 amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
                                             index_t src_thread_element_offset,
@@ -1032,10 +1114,10 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
 #if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
     uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000;

-    return amd_buffer_load_impl<scalar_t, vector_size>(
+    return amd_buffer_load_impl<scalar_t, vector_size, coherence>(
         src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
 #else
-    vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size>(
+    vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>(
         src_wave_buffer_resource, src_thread_addr_offset, 0);

     return src_thread_element_valid ? tmp : vector_t(0);
@@ -1046,7 +1128,9 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
 // 1) p_src_wave must point to global memory space
 // 2) p_src_wave must be a wavewise pointer.
 // It is user's responsibility to make sure that is true.
-template <typename T, index_t N>
+template <typename T,
+          index_t N,
+          AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
 __device__ typename vector_type_maker<T, N>::type::type
 amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
                                                         index_t src_thread_element_offset,
@@ -1064,7 +1148,7 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
     constexpr index_t vector_size = scalar_type<vector_t>::vector_size;

-    vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size>(
+    vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>(
         src_wave_buffer_resource, src_thread_addr_offset, 0);

     return src_thread_element_valid ? tmp : vector_t(customized_value);
// 1) p_dst_wave must point to global memory // 1) p_dst_wave must point to global memory
// 2) p_dst_wave must be a wavewise pointer. // 2) p_dst_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true. // It is user's responsibility to make sure that is true.
template <typename T, index_t N> template <typename T,
index_t N,
AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
__device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::type src_thread_data, __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::type src_thread_data,
T* p_dst_wave, T* p_dst_wave,
const index_t dst_thread_element_offset, const index_t dst_thread_element_offset,
...@@ -1093,12 +1179,12 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t ...@@ -1093,12 +1179,12 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK #if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000; uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
amd_buffer_store_impl<scalar_t, vector_size>( amd_buffer_store_impl<scalar_t, vector_size, coherence>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else #else
if(dst_thread_element_valid) if(dst_thread_element_valid)
{ {
amd_buffer_store_impl<scalar_t, vector_size>( amd_buffer_store_impl<scalar_t, vector_size, coherence>(
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
} }
#endif #endif
......
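Note: `AmdBufferCoherenceEnum` is threaded as a defaulted template parameter through the whole load/store stack, so existing call sites compile unchanged while new ones can select a cache policy. The enum's numeric values line up with the glc/slc bit pair that the raw-buffer intrinsics take as their final operand (see the ISA manual cited in the enum's comment), which is why a plain `static_cast<index_t>(coherence)` suffices. A hedged usage sketch of the new parameter, using only functions shown in the hunks above:

```cpp
// Sketch: request slc (streaming) loads for data that will not be reused, so it
// does not displace hot lines in cache. DefaultCoherence == 0 reproduces the old
// hard-coded behavior; the exact cache semantics of each bit are target-specific
// (an assumption based on the cited CDNA2 manual, not on this diff).
__device__ ck::float4_t load_streaming(int32x4_t src_wave_buffer_resource,
                                       ck::index_t src_thread_addr_offset)
{
    return ck::amd_buffer_load_impl<float, 4, ck::AmdBufferCoherenceEnum::SLC>(
        src_wave_buffer_resource, src_thread_addr_offset, /*src_wave_addr_offset=*/0);
}
```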
@@ -7,6 +7,7 @@
 #include "ck/utility/functional2.hpp"
 #include "ck/utility/math.hpp"

+#include <array>
 #include <cstddef>
 #include <cstdint>
 #include <type_traits>
@@ -14,29 +15,83 @@
 namespace ck {
 namespace detail {

-template <unsigned Size>
-struct get_unsigned_int;
+template <unsigned SizeInBytes>
+struct get_carrier;

 template <>
-struct get_unsigned_int<1>
+struct get_carrier<1>
 {
     using type = uint8_t;
 };

 template <>
-struct get_unsigned_int<2>
+struct get_carrier<2>
 {
     using type = uint16_t;
 };

 template <>
-struct get_unsigned_int<4>
+struct get_carrier<3>
+{
+    using type = class carrier
+    {
+        using value_type = uint32_t;
+
+        std::array<std::byte, 3> bytes;
+
+        static_assert(sizeof(bytes) <= sizeof(value_type));
+
+        // replacement of host std::copy_n()
+        template <typename InputIterator, typename Size, typename OutputIterator>
+        __device__ static OutputIterator copy_n(InputIterator from, Size size, OutputIterator to)
+        {
+            if(0 < size)
+            {
+                *to = *from;
+                ++to;
+
+                for(Size count = 1; count < size; ++count)
+                {
+                    *to = *++from;
+                    ++to;
+                }
+            }
+
+            return to;
+        }
+
+        // method to trigger template substitution failure
+        __device__ carrier(const carrier& other) noexcept
+        {
+            copy_n(other.bytes.begin(), bytes.size(), bytes.begin());
+        }
+
+    public:
+        __device__ carrier& operator=(value_type value) noexcept
+        {
+            copy_n(reinterpret_cast<const std::byte*>(&value), bytes.size(), bytes.begin());
+
+            return *this;
+        }
+
+        __device__ operator value_type() const noexcept
+        {
+            std::byte result[sizeof(value_type)];
+
+            copy_n(bytes.begin(), bytes.size(), result);
+
+            return *reinterpret_cast<const value_type*>(result);
+        }
+    };
+};
+static_assert(sizeof(get_carrier<3>::type) == 3);
+
+template <>
+struct get_carrier<4>
 {
     using type = uint32_t;
 };

-template <unsigned Size>
-using get_unsigned_int_t = typename get_unsigned_int<Size>::type;
+template <unsigned SizeInBytes>
+using get_carrier_t = typename get_carrier<SizeInBytes>::type;

 } // namespace detail
@@ -61,7 +116,7 @@ __device__ auto amd_wave_read_first_lane(const Object& obj)
     constexpr Size CompleteSgprCopyBoundary = ObjectSize - RemainedSize;
     for(Size offset = 0; offset < CompleteSgprCopyBoundary; offset += SgprSize)
     {
-        using Sgpr = detail::get_unsigned_int_t<SgprSize>;
+        using Sgpr = detail::get_carrier_t<SgprSize>;

         *reinterpret_cast<Sgpr*>(to_obj + offset) =
             amd_wave_read_first_lane(*reinterpret_cast<const Sgpr*>(from_obj + offset));
@@ -69,9 +124,9 @@ __device__ auto amd_wave_read_first_lane(const Object& obj)
     if constexpr(0 < RemainedSize)
     {
-        using Carrier = detail::get_unsigned_int_t<RemainedSize>;
+        using Carrier = detail::get_carrier_t<RemainedSize>;

-        *reinterpret_cast<Carrier>(to_obj + CompleteSgprCopyBoundary) = amd_wave_read_first_lane(
+        *reinterpret_cast<Carrier*>(to_obj + CompleteSgprCopyBoundary) = amd_wave_read_first_lane(
             *reinterpret_cast<const Carrier*>(from_obj + CompleteSgprCopyBoundary));
     }
...
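Note: `amd_wave_read_first_lane` broadcasts a trivially-copyable object across the wave by pushing it through SGPRs in 4-byte chunks, then moving any tail through a smaller "carrier" type. The old `get_unsigned_int` only covered 1-, 2- and 4-byte tails, so an object of size 4k+3 had no carrier; the new `get_carrier<3>` packs three bytes behind a `uint32_t` conversion pair. The second hunk also fixes a latent bug: the destination cast was `reinterpret_cast<Carrier>` (a value cast) and is now `reinterpret_cast<Carrier*>`. A host-side analogue of the chunking, as a sketch (the real code routes each chunk through `__builtin_amdgcn_readfirstlane` rather than `memcpy`):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <type_traits>

// Copy an object in 4-byte "SGPR" chunks plus a 1-, 2-, or 3-byte tail, mirroring
// the device-side loop; the 3-byte tail is the case get_carrier<3> newly enables.
template <typename Object>
Object chunked_copy(const Object& obj)
{
    static_assert(std::is_trivially_copyable_v<Object>);

    Object out;
    const auto* from = reinterpret_cast<const unsigned char*>(&obj);
    auto* to         = reinterpret_cast<unsigned char*>(&out);

    constexpr std::size_t SgprSize = sizeof(std::uint32_t);
    constexpr std::size_t Boundary = (sizeof(Object) / SgprSize) * SgprSize;

    for(std::size_t offset = 0; offset < Boundary; offset += SgprSize)
        std::memcpy(to + offset, from + offset, SgprSize); // whole-SGPR chunks

    std::memcpy(to + Boundary, from + Boundary, sizeof(Object) - Boundary); // tail
    return out;
}
```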
@@ -344,7 +344,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
     template <class FloatC>
     __device__ static void Run(const double& reg_a, const double& reg_b, FloatC& reg_c)
     {
-#if defined(__gfx90a__) || defined(__gfx940__)
+#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
         reg_c.template AsType<double4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f64_16x16x4f64(
             reg_a, reg_b, reg_c.template AsType<double4_t>()[Number<0>{}], 0, 0, 0);
 #else
...
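Note: the same gfx941/gfx942 extension applies to the double-precision MFMA path. For reference, a minimal sketch of the operation this guard protects, one K-step of the 16x16x4 f64 MFMA (`double4_t` is CK's 4-wide double vector typedef; the free function is illustrative, not part of the commit):

```cpp
#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// One accumulation step: c += a * b over the 16x16 tile, K = 4 per instruction.
// The trailing 0, 0, 0 are the cbsz/abid/blgp modifiers, unused here.
__device__ void mfma_f64_step(double a, double b, ck::double4_t& c)
{
    c = __builtin_amdgcn_mfma_f64_16x16x4f64(a, b, c, 0, 0, 0);
}
#endif
```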