Commit f1b2e521 authored by Anthony Chang

format

parent 4ae9919e
@@ -15,7 +15,7 @@ Outputs:
*/
-#pragma clang diagnostic ignored "-Wunused-variable" // TODO ANT: remove
+#pragma clang diagnostic ignored "-Wunused-variable" // TODO ANT: remove
#define PRINT_HOST 0
@@ -361,10 +361,12 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
v_gs_ns_os_lengths_vec, v_gs_ns_os_strides_vec)
.second;
// LogRangeAsType<float>(std::cout << "v_gs_os_ns_lengths_vec: ", v_gs_os_ns_lengths_vec, ",") << std::endl;
// LogRangeAsType<float>(std::cout << "v_gs_os_ns_strides_vec: ", v_gs_os_ns_strides_vec, ",") << std::endl;
// LogRangeAsType<float>(std::cout << "v_gs_ns_os_lengths_vec: ", v_gs_ns_os_lengths_vec, ",") << std::endl;
// LogRangeAsType<float>(std::cout << "v_gs_ns_os_strides_vec: ", v_gs_ns_os_strides_vec, ",") << std::endl;
// LogRangeAsType<float>(std::cout << "v_gs_os_ns_lengths_vec: ", v_gs_os_ns_lengths_vec,
// ",") << std::endl; LogRangeAsType<float>(std::cout << "v_gs_os_ns_strides_vec: ",
// v_gs_os_ns_strides_vec, ",") << std::endl; LogRangeAsType<float>(std::cout <<
// "v_gs_ns_os_lengths_vec: ", v_gs_ns_os_lengths_vec, ",") << std::endl;
// LogRangeAsType<float>(std::cout << "v_gs_ns_os_strides_vec: ", v_gs_ns_os_strides_vec,
// ",") << std::endl;
return PadTensorDescriptor(vgrad_desc_nraw_oraw,
make_tuple(NPerBlock, Gemm1NPerBlock),
Sequence<padder.PadN, padder.PadO>{});
@@ -685,7 +687,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
batch_count_{c_grid_desc_g_m_n_.GetLength(I0)},
compute_base_ptr_of_batch_{
-                  a_grid_desc_g_m_k_, b_grid_desc_g_n_k_, b1_grid_desc_g_n_k_, c_grid_desc_g_m_n_, type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())}
+                  a_grid_desc_g_m_k_,
+                  b_grid_desc_g_n_k_,
+                  b1_grid_desc_g_n_k_,
+                  c_grid_desc_g_m_n_,
+                  type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())}
{
// TODO ANT: implement bias addition
ignore = p_acc0_biases;
@@ -726,9 +732,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
<< c_grid_desc_g_m_n_.GetLength(I1) << ", "
<< c_grid_desc_g_m_n_.GetLength(I2) << '\n';
// c_grid_desc_g_m_n_.Print();
std::cout << "vgrad_grid_desc_n_o_: " << vgrad_grid_desc_n_o_.GetLength(I0) << ", " << vgrad_grid_desc_n_o_.GetLength(I1) << '\n';
std::cout << "ygrad_grid_desc_m0_o_m1_: " << ygrad_grid_desc_m0_o_m1_.GetLength(I0) << ", "
<< ygrad_grid_desc_m0_o_m1_.GetLength(I1) << ", "
std::cout << "vgrad_grid_desc_n_o_: " << vgrad_grid_desc_n_o_.GetLength(I0) << ", "
<< vgrad_grid_desc_n_o_.GetLength(I1) << '\n';
std::cout << "ygrad_grid_desc_m0_o_m1_: " << ygrad_grid_desc_m0_o_m1_.GetLength(I0)
<< ", " << ygrad_grid_desc_m0_o_m1_.GetLength(I1) << ", "
<< ygrad_grid_desc_m0_o_m1_.GetLength(I2) << '\n';
}
@@ -415,13 +415,14 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
__host__ __device__ static constexpr auto
MakeLSEGridDescriptor_MBlock_MRepeat_NWave_MPerXdl(const LSEGridDesc_M& lse_grid_desc_m)
{
-const index_t M = lse_grid_desc_m.GetLength(I0);
-const index_t MBlock = M / MPerBlock;
-constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
+const index_t M         = lse_grid_desc_m.GetLength(I0);
+const index_t MBlock    = M / MPerBlock;
+constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
const auto lse_grid_desc_mblock_mrepeat_mwave_mperxdl = transform_tensor_descriptor(
lse_grid_desc_m,
-    make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MXdlPerWave>{}, MWave, Number<MPerXdl>{}))),
+    make_tuple(make_unmerge_transform(
+        make_tuple(MBlock, Number<MXdlPerWave>{}, MWave, Number<MPerXdl>{}))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2, 3>{}));
@@ -469,10 +470,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
static constexpr auto ygrad_block_space_size_aligned = math::integer_least_multiple(
ygrad_block_desc_m0_o_m1.GetElementSpaceSize(), max_lds_align);
-static constexpr auto a_block_space_offset = 0;
-static constexpr auto b_block_space_offset = a_block_space_size_aligned.value;
-static constexpr auto b1_block_space_offset = 0;
-static constexpr auto p_block_space_offset = 0;
+static constexpr auto a_block_space_offset     = 0;
+static constexpr auto b_block_space_offset     = a_block_space_size_aligned.value;
+static constexpr auto b1_block_space_offset    = 0;
+static constexpr auto p_block_space_offset     = 0;
static constexpr auto ygrad_block_space_offset = p_block_space_size_aligned.value;
// LDS allocation for reduction
@@ -1015,18 +1016,17 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
-true>{
-    p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
-    make_multi_index(
-        p_thread_origin_nd_idx_on_block[I0],
-        p_thread_origin_nd_idx_on_block[I1],
-        p_thread_origin_nd_idx_on_block[I2] % p_block_slice_lengths_m0_n0_m1_n1[I2],
-        p_thread_origin_nd_idx_on_block[I3] % p_block_slice_lengths_m0_n0_m1_n1[I3],
-        p_thread_origin_nd_idx_on_block[I4],
-        p_thread_origin_nd_idx_on_block[I5],
-        p_thread_origin_nd_idx_on_block[I6],
-        p_thread_origin_nd_idx_on_block[I7]),
-    tensor_operation::element_wise::PassThrough{}};
+true>{p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+      make_multi_index(
+          p_thread_origin_nd_idx_on_block[I0],
+          p_thread_origin_nd_idx_on_block[I1],
+          p_thread_origin_nd_idx_on_block[I2] % p_block_slice_lengths_m0_n0_m1_n1[I2],
+          p_thread_origin_nd_idx_on_block[I3] % p_block_slice_lengths_m0_n0_m1_n1[I3],
+          p_thread_origin_nd_idx_on_block[I4],
+          p_thread_origin_nd_idx_on_block[I5],
+          p_thread_origin_nd_idx_on_block[I6],
+          p_thread_origin_nd_idx_on_block[I7]),
+      tensor_operation::element_wise::PassThrough{}};
// Sequence<p_block_slice_lengths_m0_n0_m1_n1[I0],
// p_block_slice_lengths_m0_n0_m1_n1[I1],
@@ -1087,18 +1087,18 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
static_cast<DataType*>(p_shared) + SharedMemTrait::ygrad_block_space_offset,
ygrad_block_desc_m0_o_m1.GetElementSpaceSize());
-auto vgrad_blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
-    BlockSize,
-    DataType,
-    FloatGemmAcc,
-    decltype(p_block_desc_m0_n_m1),
-    decltype(ygrad_block_desc_m0_o_m1),
-    MPerXdl,
-    NPerXdl,
-    VGradGemmTile_N_O_M::GemmNRepeat,
-    VGradGemmTile_N_O_M::GemmORepeat,
-    VGradGemmTile_N_O_M::GemmMPack,
-    true>{}; // TransposeC
+auto vgrad_blockwise_gemm =
+    BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
+                                                        DataType,
+                                                        FloatGemmAcc,
+                                                        decltype(p_block_desc_m0_n_m1),
+                                                        decltype(ygrad_block_desc_m0_o_m1),
+                                                        MPerXdl,
+                                                        NPerXdl,
+                                                        VGradGemmTile_N_O_M::GemmNRepeat,
+                                                        VGradGemmTile_N_O_M::GemmORepeat,
+                                                        VGradGemmTile_N_O_M::GemmMPack,
+                                                        true>{}; // TransposeC
auto vgrad_acc_thread_buf = vgrad_blockwise_gemm.GetCThreadBuffer();
@@ -1107,12 +1107,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
const auto vgrad_grid_desc_n0_o0_n1_o1_n2_o2 = transform_tensor_descriptor(
vgrad_grid_desc_n_o,
make_tuple(
-    make_unmerge_transform(make_tuple(I1,
-                                      VGradGemmTile_N_O_M::GemmNWave,
-                                      MPerXdl)),
-    make_unmerge_transform(make_tuple(I1,
-                                      VGradGemmTile_N_O_M::GemmOWave,
-                                      NPerXdl))),
+    make_unmerge_transform(make_tuple(I1, VGradGemmTile_N_O_M::GemmNWave, MPerXdl)),
+    make_unmerge_transform(make_tuple(I1, VGradGemmTile_N_O_M::GemmOWave, NPerXdl))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
@@ -1142,12 +1138,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
const index_t o_thread_data_idx_on_grid =
vgrad_thread_mtx_on_block_n_o[I1] + gemm1_n_block_data_idx_on_grid;
-const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor =
-    make_single_stage_tensor_adaptor(
-        make_tuple(make_merge_transform(
-            make_tuple(VGrad_N0, VGrad_N1, VGrad_N2))),
-        make_tuple(Sequence<0, 1, 2>{}),
-        make_tuple(Sequence<0>{}));
+const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
+    make_tuple(make_merge_transform(make_tuple(VGrad_N0, VGrad_N1, VGrad_N2))),
+    make_tuple(Sequence<0, 1, 2>{}),
+    make_tuple(Sequence<0>{}));
const auto n_thread_data_nd_idx_on_grid =
n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
@@ -1171,9 +1165,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
decltype(vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4),
tensor_operation::element_wise::PassThrough, // CElementwiseOperation
decltype(vgrad_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLengths()), // SliceLengths
-Sequence<0, 1, 2, 3, 4, 5, 6, 7>, // AccessOrder
-7, // VectorDim
-2, // ScalarPerVector
+Sequence<0, 1, 2, 3, 4, 5, 6, 7>,     // AccessOrder
+7,                                    // VectorDim
+2,                                    // ScalarPerVector
InMemoryDataOperationEnum::AtomicAdd, // GlobalMemoryDataOperation
1, // DstScalarStrideInVector
true>(vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
@@ -1226,10 +1220,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
constexpr index_t num_vgrad_gemm_loop = MPerBlock / VGradGemmTile_N_O_M::Sum_M;
lse_thread_copy_global_to_vgpr.Run(lse_grid_desc_mblock_mrepeat_mwave_mperxdl,
-                                   lse_grid_buf,
-                                   lse_thread_desc_mblock_mrepeat_mwave_mperxdl,
-                                   make_tuple(I0, I0, I0, I0),
-                                   lse_thread_buf);
+                                  lse_grid_buf,
+                                  lse_thread_desc_mblock_mrepeat_mwave_mperxdl,
+                                  make_tuple(I0, I0, I0, I0),
+                                  lse_thread_buf);
// gemm1 K loop
index_t gemm1_k_block_outer_index = 0;
@@ -1366,10 +1360,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// load VGrad Gemm A
const auto p_nd_idx =
sfc_p_m0_n0_m1_n1_m2_n2.GetIndexTupleOfNumber(vgrad_gemm_loop_idx);
-constexpr auto mwave_range = make_tuple(
-    p_nd_idx[I2], p_nd_idx[I2] + p_block_slice_lengths_m0_n0_m1_n1[I2]);
-constexpr auto nwave_range = make_tuple(
-    p_nd_idx[I3], p_nd_idx[I3] + p_block_slice_lengths_m0_n0_m1_n1[I3]);
+constexpr auto mwave_range =
+    make_tuple(p_nd_idx[I2], p_nd_idx[I2] + p_block_slice_lengths_m0_n0_m1_n1[I2]);
+constexpr auto nwave_range =
+    make_tuple(p_nd_idx[I3], p_nd_idx[I3] + p_block_slice_lengths_m0_n0_m1_n1[I3]);
#if 0
if(hipThreadIdx_x % 64 == 0)
{
......@@ -1385,12 +1379,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
p_thread_copy_subgroup.IsBelong(mwave_range, nwave_range));
}
#endif
-if (p_thread_copy_subgroup.IsBelong(mwave_range, nwave_range))
+if(p_thread_copy_subgroup.IsBelong(mwave_range, nwave_range))
{
p_thread_copy_vgpr_to_lds.Run(
p_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
-    make_tuple(
-        p_nd_idx[I0], p_nd_idx[I1], I0, I0, I0, I0, I0, I0),
+    make_tuple(p_nd_idx[I0], p_nd_idx[I1], I0, I0, I0, I0, I0, I0),
acc_thread_buf,
p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
p_block_buf);