"...colossalai/train_dreambooth_colossalai.py" did not exist on "ac84c2fa5a92a0a86e351efd6e0e65f68b69815b"
Commit ad82c377 authored by carlushuang's avatar carlushuang
Browse files

fix build issue

parent 2f463a94
......@@ -75,7 +75,10 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk<
BlockSize,
BlockToCTileMap_GemmStreamK<MPerBlock, NPerBlock, K0PerBlock * K1>,
BlockToCTileMap_GemmStreamK<MPerBlock,
NPerBlock,
K0PerBlock * K1,
StreamKReductionStrategy::Reduction>,
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CDataType,
......@@ -151,8 +154,8 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
{
char* workspace_semaphore = reinterpret_cast<char*>(karg.p_workspace_);
workspace_semaphore =
workspace_semaphore +
karg.block_mapping.get_workspace_size_for_acc(sizeof(GridwiseGemm::FloatAcc));
workspace_semaphore + karg.block_mapping.get_workspace_size_for_acc(
sizeof(typename GridwiseGemm::FloatAcc));
hipGetErrorString(hipMemset(
workspace_semaphore, 0, karg.block_mapping.get_workspace_size_for_semaphore()));
}
......@@ -191,7 +194,7 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
return p_arg->block_mapping.get_workspace_size(sizeof(GridwiseGemm::FloatAcc));
return p_arg->block_mapping.get_workspace_size(sizeof(typename GridwiseGemm::FloatAcc));
}
else
{
......
......@@ -977,7 +977,8 @@ struct BlockToCTileMap_GemmStreamK
return get_workspace_size_for_acc(acc_element_bytes) + get_workspace_size_for_semaphore();
}
__device__ uint32_t get_tile_intersections(uint32_t tiles_, const MDiv& eqav_tiles_) const
__host__ __device__ uint32_t get_tile_intersections(uint32_t tiles_,
const MDiv& eqav_tiles_) const
{
uint32_t tile_idx_ = tiles_ == 0 ? 0 : (tiles_ - 1);
uint32_t max_eqav_tiles_ = eqav_tiles_.get() - 1;
......
......@@ -556,8 +556,18 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
Number<NPerBlock>{},
cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
constexpr auto acc_thread_buf_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));
// constexpr auto acc_thread_buf_desc = make_naive_tensor_descriptor_packed(
// make_tuple(Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));
constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<1>{}, Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));
constexpr auto acc_thread_buf_store_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<1>{},
Number<1>{},
Number<1>{},
Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));
constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor();
constexpr auto partial_acc_load_step_n = make_multi_index(
......@@ -610,31 +620,33 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
FloatAcc, // SrcData,
FloatAcc, // DstData,
decltype(c_partial_acc_block_m_n), // SrcDesc,
decltype(acc_thread_buf_desc), // DstDesc,
Sequence<CBlockTransferScalarPerVector_NWaveNPerXDL>, // SliceLengths,
Sequence<I0>, // DimAccessOrder,
2, // SrcVectorDim,
decltype(acc_thread_buf_load_desc), // DstDesc,
Sequence<1, CBlockTransferScalarPerVector_NWaveNPerXDL>, // SliceLengths,
Sequence<0, 1>, // DimAccessOrder,
1, // SrcVectorDim,
CBlockTransferScalarPerVector_NWaveNPerXDL, // SrcScalarPerVector,
1, // SrcScalarStrideInVector,
false // SrcResetCoordinateAfterRun,
>{c_partial_acc_block_m_n,
make_multi_index(static_cast<index_t>(tile_acc_offset_start), I0, I0)};
>{c_partial_acc_block_m_n, make_multi_index(0, 0)};
auto acc_store = ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc, // SrcData,
FloatC, // DstData,
decltype(acc_thread_buf_desc), // SrcDesc,
decltype(acc_thread_buf_store_desc), // SrcDesc,
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), // DstDesc,
CElementwiseOperation, // ElementwiseOperation,
Sequence<CBlockTransferScalarPerVector_NWaveNPerXDL>, // SliceLengths,
Sequence<I0>, // DimAccessOrder,
Sequence<0, 0, 0, CBlockTransferScalarPerVector_NWaveNPerXDL>, // SliceLengths,
Sequence<0, 1, 2, 3>, // DimAccessOrder,
2, // DstVectorDim,
CBlockTransferScalarPerVector_NWaveNPerXDL, // DstScalarPerVector,
InMemoryDataOperationEnum::Set, // InMemoryDataOperationEnum DstInMemOp,
1, // DstScalarStrideInVector,
3, // DstScalarStrideInVector,
false // DstResetCoordinateAfterRun,
>{c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(spatial_idx[I0], I0, spatial_idx[I1], I0),
make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
0,
__builtin_amdgcn_readfirstlane(spatial_idx[I1]),
0),
CElementwiseOperation{}};
// block synchronization
......@@ -653,21 +665,22 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
acc_load.Run(c_partial_acc_block_m_n,
c_partial_acc_buf,
acc_thread_buf_desc,
make_multi_index(I0),
acc_thread_buf_load_desc,
make_tuple(I0, I0),
parcial_acc_buf);
static_for<0, CBlockTransferScalarPerVector_NWaveNPerXDL, 1>{}(
[&](auto i_vec) {
constexpr auto offset =
acc_thread_buf_desc.CalculateOffset(make_tuple(i_vec));
acc_thread_buf_load_desc.CalculateOffset(
make_tuple(0, i_vec));
Accumulation::Calculate(acc_buf(Number<offset>{}),
parcial_acc_buf[Number<offset>{}]);
});
}
acc_store.Run(acc_thread_buf_desc,
make_multi_index(I0),
acc_store.Run(acc_thread_buf_store_desc,
make_tuple(I0, I0, I0, I0),
acc_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
......@@ -678,7 +691,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
partial_acc_load_step_n);
acc_store.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
partial_acc_load_step_n);
partial_acc_store_step_n);
}
else
{
......@@ -964,7 +977,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData,
FloatC, // typename DstData,
FloatCShuffle, // typename DstData,
decltype(c_block_desc_mblock_mpershuffle_nblock_npershuffle),
decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
......@@ -1030,11 +1043,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
// constexpr offset
c_block_copy_lds_to_partial_acc.SetDstSliceOrigin(
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
make_tuple(mxdlperwave, I0, nxdlperwave, I0));
make_tuple(mxdlperwave.value, 0, nxdlperwave.value, 0));
c_block_copy_lds_to_partial_acc
.template Run<decltype(c_block_buf),
decltype(c_block_buf),
decltype(c_partial_acc_buf),
InMemoryDataOperationEnum::Set>(
c_block_desc_mblock_mpershuffle_nblock_npershuffle,
c_block_buf,
......
......@@ -549,7 +549,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
{
static_assert(
(is_same<T, double>::value && (N == 1 || N == 2)) ||
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) ||
......@@ -602,6 +602,20 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
dst_wave_addr_offset,
0);
}
else if constexpr(N == 8)
{
vector_type<float, 8> tmp{src_thread_data};
llvm_amdgcn_raw_buffer_store_fp32x4(tmp.AsType<float4_t>()[Number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
llvm_amdgcn_raw_buffer_store_fp32x4(tmp.AsType<float4_t>()[Number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(float),
0);
}
}
else if constexpr(is_same<T, half_t>::value)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment