Commit ad82c377 authored by carlushuang
Browse files

fix build issue

parent 2f463a94
...@@ -75,7 +75,10 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout, ...@@ -75,7 +75,10 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk< using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk<
BlockSize, BlockSize,
BlockToCTileMap_GemmStreamK<MPerBlock, NPerBlock, K0PerBlock * K1>, BlockToCTileMap_GemmStreamK<MPerBlock,
NPerBlock,
K0PerBlock * K1,
StreamKReductionStrategy::Reduction>,
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
AccDataType, AccDataType,
CDataType, CDataType,
...@@ -151,8 +154,8 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout, ...@@ -151,8 +154,8 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
{ {
char* workspace_semaphore = reinterpret_cast<char*>(karg.p_workspace_); char* workspace_semaphore = reinterpret_cast<char*>(karg.p_workspace_);
workspace_semaphore = workspace_semaphore =
workspace_semaphore + workspace_semaphore + karg.block_mapping.get_workspace_size_for_acc(
karg.block_mapping.get_workspace_size_for_acc(sizeof(GridwiseGemm::FloatAcc)); sizeof(typename GridwiseGemm::FloatAcc));
hipGetErrorString(hipMemset( hipGetErrorString(hipMemset(
workspace_semaphore, 0, karg.block_mapping.get_workspace_size_for_semaphore())); workspace_semaphore, 0, karg.block_mapping.get_workspace_size_for_semaphore()));
} }
...@@ -191,7 +194,7 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout, ...@@ -191,7 +194,7 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy == if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy ==
StreamKReductionStrategy::Reduction) StreamKReductionStrategy::Reduction)
{ {
return p_arg->block_mapping.get_workspace_size(sizeof(GridwiseGemm::FloatAcc)); return p_arg->block_mapping.get_workspace_size(sizeof(typename GridwiseGemm::FloatAcc));
} }
else else
{ {
......
...@@ -977,7 +977,8 @@ struct BlockToCTileMap_GemmStreamK ...@@ -977,7 +977,8 @@ struct BlockToCTileMap_GemmStreamK
return get_workspace_size_for_acc(acc_element_bytes) + get_workspace_size_for_semaphore(); return get_workspace_size_for_acc(acc_element_bytes) + get_workspace_size_for_semaphore();
} }
__device__ uint32_t get_tile_intersections(uint32_t tiles_, const MDiv& eqav_tiles_) const __host__ __device__ uint32_t get_tile_intersections(uint32_t tiles_,
const MDiv& eqav_tiles_) const
{ {
uint32_t tile_idx_ = tiles_ == 0 ? 0 : (tiles_ - 1); uint32_t tile_idx_ = tiles_ == 0 ? 0 : (tiles_ - 1);
uint32_t max_eqav_tiles_ = eqav_tiles_.get() - 1; uint32_t max_eqav_tiles_ = eqav_tiles_.get() - 1;
......
...@@ -556,8 +556,18 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk ...@@ -556,8 +556,18 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
Number<NPerBlock>{}, Number<NPerBlock>{},
cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL); cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
constexpr auto acc_thread_buf_desc = make_naive_tensor_descriptor_packed( // constexpr auto acc_thread_buf_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{})); // make_tuple(Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));
constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<1>{}, Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));
constexpr auto acc_thread_buf_store_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<1>{},
Number<1>{},
Number<1>{},
Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));
constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor(); constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor();
constexpr auto partial_acc_load_step_n = make_multi_index( constexpr auto partial_acc_load_step_n = make_multi_index(
...@@ -607,34 +617,36 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk ...@@ -607,34 +617,36 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
block_mapping.get_acc_buffer_offset_from_tile(reduction_idx + 1); block_mapping.get_acc_buffer_offset_from_tile(reduction_idx + 1);
auto acc_load = ThreadwiseTensorSliceTransfer_v2< auto acc_load = ThreadwiseTensorSliceTransfer_v2<
FloatAcc, // SrcData, FloatAcc, // SrcData,
FloatAcc, // DstData, FloatAcc, // DstData,
decltype(c_partial_acc_block_m_n), // SrcDesc, decltype(c_partial_acc_block_m_n), // SrcDesc,
decltype(acc_thread_buf_desc), // DstDesc, decltype(acc_thread_buf_load_desc), // DstDesc,
Sequence<CBlockTransferScalarPerVector_NWaveNPerXDL>, // SliceLengths, Sequence<1, CBlockTransferScalarPerVector_NWaveNPerXDL>, // SliceLengths,
Sequence<I0>, // DimAccessOrder, Sequence<0, 1>, // DimAccessOrder,
2, // SrcVectorDim, 1, // SrcVectorDim,
CBlockTransferScalarPerVector_NWaveNPerXDL, // SrcScalarPerVector, CBlockTransferScalarPerVector_NWaveNPerXDL, // SrcScalarPerVector,
1, // SrcScalarStrideInVector, 1, // SrcScalarStrideInVector,
false // SrcResetCoordinateAfterRun, false // SrcResetCoordinateAfterRun,
>{c_partial_acc_block_m_n, >{c_partial_acc_block_m_n, make_multi_index(0, 0)};
make_multi_index(static_cast<index_t>(tile_acc_offset_start), I0, I0)};
auto acc_store = ThreadwiseTensorSliceTransfer_v1r3< auto acc_store = ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc, // SrcData, FloatAcc, // SrcData,
FloatC, // DstData, FloatC, // DstData,
decltype(acc_thread_buf_desc), // SrcDesc, decltype(acc_thread_buf_store_desc), // SrcDesc,
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), // DstDesc, decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), // DstDesc,
CElementwiseOperation, // ElementwiseOperation, CElementwiseOperation, // ElementwiseOperation,
Sequence<CBlockTransferScalarPerVector_NWaveNPerXDL>, // SliceLengths, Sequence<0, 0, 0, CBlockTransferScalarPerVector_NWaveNPerXDL>, // SliceLengths,
Sequence<I0>, // DimAccessOrder, Sequence<0, 1, 2, 3>, // DimAccessOrder,
2, // DstVectorDim, 2, // DstVectorDim,
CBlockTransferScalarPerVector_NWaveNPerXDL, // DstScalarPerVector, CBlockTransferScalarPerVector_NWaveNPerXDL, // DstScalarPerVector,
InMemoryDataOperationEnum::Set, // InMemoryDataOperationEnum DstInMemOp, InMemoryDataOperationEnum::Set, // InMemoryDataOperationEnum DstInMemOp,
1, // DstScalarStrideInVector, 3, // DstScalarStrideInVector,
false // DstResetCoordinateAfterRun, false // DstResetCoordinateAfterRun,
>{c_grid_desc_mblock_mperblock_nblock_nperblock, >{c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(spatial_idx[I0], I0, spatial_idx[I1], I0), make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
0,
__builtin_amdgcn_readfirstlane(spatial_idx[I1]),
0),
CElementwiseOperation{}}; CElementwiseOperation{}};
// block synchronization // block synchronization
...@@ -653,21 +665,22 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk ...@@ -653,21 +665,22 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
acc_load.Run(c_partial_acc_block_m_n, acc_load.Run(c_partial_acc_block_m_n,
c_partial_acc_buf, c_partial_acc_buf,
acc_thread_buf_desc, acc_thread_buf_load_desc,
make_multi_index(I0), make_tuple(I0, I0),
parcial_acc_buf); parcial_acc_buf);
static_for<0, CBlockTransferScalarPerVector_NWaveNPerXDL, 1>{}( static_for<0, CBlockTransferScalarPerVector_NWaveNPerXDL, 1>{}(
[&](auto i_vec) { [&](auto i_vec) {
constexpr auto offset = constexpr auto offset =
acc_thread_buf_desc.CalculateOffset(make_tuple(i_vec)); acc_thread_buf_load_desc.CalculateOffset(
make_tuple(0, i_vec));
Accumulation::Calculate(acc_buf(Number<offset>{}), Accumulation::Calculate(acc_buf(Number<offset>{}),
parcial_acc_buf[Number<offset>{}]); parcial_acc_buf[Number<offset>{}]);
}); });
} }
acc_store.Run(acc_thread_buf_desc, acc_store.Run(acc_thread_buf_store_desc,
make_multi_index(I0), make_tuple(I0, I0, I0, I0),
acc_buf, acc_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf); c_grid_buf);
...@@ -678,7 +691,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk ...@@ -678,7 +691,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
partial_acc_load_step_n); partial_acc_load_step_n);
acc_store.MoveDstSliceWindow( acc_store.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
partial_acc_load_step_n); partial_acc_store_step_n);
} }
else else
{ {
...@@ -964,7 +977,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk ...@@ -964,7 +977,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData, FloatCShuffle, // typename SrcData,
FloatC, // typename DstData, FloatCShuffle, // typename DstData,
decltype(c_block_desc_mblock_mpershuffle_nblock_npershuffle), decltype(c_block_desc_mblock_mpershuffle_nblock_npershuffle),
decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle), decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder, Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
...@@ -1030,11 +1043,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk ...@@ -1030,11 +1043,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
// constexpr offset // constexpr offset
c_block_copy_lds_to_partial_acc.SetDstSliceOrigin( c_block_copy_lds_to_partial_acc.SetDstSliceOrigin(
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle, c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
make_tuple(mxdlperwave, I0, nxdlperwave, I0)); make_tuple(mxdlperwave.value, 0, nxdlperwave.value, 0));
c_block_copy_lds_to_partial_acc c_block_copy_lds_to_partial_acc
.template Run<decltype(c_block_buf), .template Run<decltype(c_block_buf),
decltype(c_block_buf), decltype(c_partial_acc_buf),
InMemoryDataOperationEnum::Set>( InMemoryDataOperationEnum::Set>(
c_block_desc_mblock_mpershuffle_nblock_npershuffle, c_block_desc_mblock_mpershuffle_nblock_npershuffle,
c_block_buf, c_block_buf,
......
...@@ -549,7 +549,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src ...@@ -549,7 +549,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
{ {
static_assert( static_assert(
(is_same<T, double>::value && (N == 1 || N == 2)) || (is_same<T, double>::value && (N == 1 || N == 2)) ||
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) || (is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) || (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) ||
...@@ -602,6 +602,20 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src ...@@ -602,6 +602,20 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
dst_wave_addr_offset, dst_wave_addr_offset,
0); 0);
} }
else if constexpr(N == 8)
{
vector_type<float, 8> tmp{src_thread_data};
llvm_amdgcn_raw_buffer_store_fp32x4(tmp.AsType<float4_t>()[Number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
llvm_amdgcn_raw_buffer_store_fp32x4(tmp.AsType<float4_t>()[Number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(float),
0);
}
} }
else if constexpr(is_same<T, half_t>::value) else if constexpr(is_same<T, half_t>::value)
{ {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment