Commit f1b2e521 authored by Anthony Chang
Browse files

format

parent 4ae9919e
...@@ -361,10 +361,12 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -361,10 +361,12 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
v_gs_ns_os_lengths_vec, v_gs_ns_os_strides_vec) v_gs_ns_os_lengths_vec, v_gs_ns_os_strides_vec)
.second; .second;
// LogRangeAsType<float>(std::cout << "v_gs_os_ns_lengths_vec: ", v_gs_os_ns_lengths_vec, ",") << std::endl; // LogRangeAsType<float>(std::cout << "v_gs_os_ns_lengths_vec: ", v_gs_os_ns_lengths_vec,
// LogRangeAsType<float>(std::cout << "v_gs_os_ns_strides_vec: ", v_gs_os_ns_strides_vec, ",") << std::endl; // ",") << std::endl; LogRangeAsType<float>(std::cout << "v_gs_os_ns_strides_vec: ",
// LogRangeAsType<float>(std::cout << "v_gs_ns_os_lengths_vec: ", v_gs_ns_os_lengths_vec, ",") << std::endl; // v_gs_os_ns_strides_vec, ",") << std::endl; LogRangeAsType<float>(std::cout <<
// LogRangeAsType<float>(std::cout << "v_gs_ns_os_strides_vec: ", v_gs_ns_os_strides_vec, ",") << std::endl; // "v_gs_ns_os_lengths_vec: ", v_gs_ns_os_lengths_vec, ",") << std::endl;
// LogRangeAsType<float>(std::cout << "v_gs_ns_os_strides_vec: ", v_gs_ns_os_strides_vec,
// ",") << std::endl;
return PadTensorDescriptor(vgrad_desc_nraw_oraw, return PadTensorDescriptor(vgrad_desc_nraw_oraw,
make_tuple(NPerBlock, Gemm1NPerBlock), make_tuple(NPerBlock, Gemm1NPerBlock),
Sequence<padder.PadN, padder.PadO>{}); Sequence<padder.PadN, padder.PadO>{});
...@@ -685,7 +687,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -685,7 +687,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]}, c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
batch_count_{c_grid_desc_g_m_n_.GetLength(I0)}, batch_count_{c_grid_desc_g_m_n_.GetLength(I0)},
compute_base_ptr_of_batch_{ compute_base_ptr_of_batch_{
a_grid_desc_g_m_k_, b_grid_desc_g_n_k_, b1_grid_desc_g_n_k_, c_grid_desc_g_m_n_, type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())} a_grid_desc_g_m_k_,
b_grid_desc_g_n_k_,
b1_grid_desc_g_n_k_,
c_grid_desc_g_m_n_,
type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())}
{ {
// TODO ANT: implement bias addition // TODO ANT: implement bias addition
ignore = p_acc0_biases; ignore = p_acc0_biases;
...@@ -726,9 +732,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -726,9 +732,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
<< c_grid_desc_g_m_n_.GetLength(I1) << ", " << c_grid_desc_g_m_n_.GetLength(I1) << ", "
<< c_grid_desc_g_m_n_.GetLength(I2) << '\n'; << c_grid_desc_g_m_n_.GetLength(I2) << '\n';
// c_grid_desc_g_m_n_.Print(); // c_grid_desc_g_m_n_.Print();
std::cout << "vgrad_grid_desc_n_o_: " << vgrad_grid_desc_n_o_.GetLength(I0) << ", " << vgrad_grid_desc_n_o_.GetLength(I1) << '\n'; std::cout << "vgrad_grid_desc_n_o_: " << vgrad_grid_desc_n_o_.GetLength(I0) << ", "
std::cout << "ygrad_grid_desc_m0_o_m1_: " << ygrad_grid_desc_m0_o_m1_.GetLength(I0) << ", " << vgrad_grid_desc_n_o_.GetLength(I1) << '\n';
<< ygrad_grid_desc_m0_o_m1_.GetLength(I1) << ", " std::cout << "ygrad_grid_desc_m0_o_m1_: " << ygrad_grid_desc_m0_o_m1_.GetLength(I0)
<< ", " << ygrad_grid_desc_m0_o_m1_.GetLength(I1) << ", "
<< ygrad_grid_desc_m0_o_m1_.GetLength(I2) << '\n'; << ygrad_grid_desc_m0_o_m1_.GetLength(I2) << '\n';
} }
......
...@@ -421,7 +421,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -421,7 +421,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
const auto lse_grid_desc_mblock_mrepeat_mwave_mperxdl = transform_tensor_descriptor( const auto lse_grid_desc_mblock_mrepeat_mwave_mperxdl = transform_tensor_descriptor(
lse_grid_desc_m, lse_grid_desc_m,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MXdlPerWave>{}, MWave, Number<MPerXdl>{}))), make_tuple(make_unmerge_transform(
make_tuple(MBlock, Number<MXdlPerWave>{}, MWave, Number<MPerXdl>{}))),
make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2, 3>{})); make_tuple(Sequence<0, 1, 2, 3>{}));
...@@ -1015,8 +1016,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -1015,8 +1016,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
1, // DstScalarPerVector 1, // DstScalarPerVector
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector 1, // DstScalarStrideInVector
true>{ true>{p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_multi_index( make_multi_index(
p_thread_origin_nd_idx_on_block[I0], p_thread_origin_nd_idx_on_block[I0],
p_thread_origin_nd_idx_on_block[I1], p_thread_origin_nd_idx_on_block[I1],
...@@ -1087,8 +1087,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -1087,8 +1087,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
static_cast<DataType*>(p_shared) + SharedMemTrait::ygrad_block_space_offset, static_cast<DataType*>(p_shared) + SharedMemTrait::ygrad_block_space_offset,
ygrad_block_desc_m0_o_m1.GetElementSpaceSize()); ygrad_block_desc_m0_o_m1.GetElementSpaceSize());
auto vgrad_blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1< auto vgrad_blockwise_gemm =
BlockSize, BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
DataType, DataType,
FloatGemmAcc, FloatGemmAcc,
decltype(p_block_desc_m0_n_m1), decltype(p_block_desc_m0_n_m1),
...@@ -1107,12 +1107,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -1107,12 +1107,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
const auto vgrad_grid_desc_n0_o0_n1_o1_n2_o2 = transform_tensor_descriptor( const auto vgrad_grid_desc_n0_o0_n1_o1_n2_o2 = transform_tensor_descriptor(
vgrad_grid_desc_n_o, vgrad_grid_desc_n_o,
make_tuple( make_tuple(
make_unmerge_transform(make_tuple(I1, make_unmerge_transform(make_tuple(I1, VGradGemmTile_N_O_M::GemmNWave, MPerXdl)),
VGradGemmTile_N_O_M::GemmNWave, make_unmerge_transform(make_tuple(I1, VGradGemmTile_N_O_M::GemmOWave, NPerXdl))),
MPerXdl)),
make_unmerge_transform(make_tuple(I1,
VGradGemmTile_N_O_M::GemmOWave,
NPerXdl))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
...@@ -1142,10 +1138,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -1142,10 +1138,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
const index_t o_thread_data_idx_on_grid = const index_t o_thread_data_idx_on_grid =
vgrad_thread_mtx_on_block_n_o[I1] + gemm1_n_block_data_idx_on_grid; vgrad_thread_mtx_on_block_n_o[I1] + gemm1_n_block_data_idx_on_grid;
const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
make_single_stage_tensor_adaptor( make_tuple(make_merge_transform(make_tuple(VGrad_N0, VGrad_N1, VGrad_N2))),
make_tuple(make_merge_transform(
make_tuple(VGrad_N0, VGrad_N1, VGrad_N2))),
make_tuple(Sequence<0, 1, 2>{}), make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
...@@ -1366,10 +1360,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -1366,10 +1360,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// load VGrad Gemm A // load VGrad Gemm A
const auto p_nd_idx = const auto p_nd_idx =
sfc_p_m0_n0_m1_n1_m2_n2.GetIndexTupleOfNumber(vgrad_gemm_loop_idx); sfc_p_m0_n0_m1_n1_m2_n2.GetIndexTupleOfNumber(vgrad_gemm_loop_idx);
constexpr auto mwave_range = make_tuple( constexpr auto mwave_range =
p_nd_idx[I2], p_nd_idx[I2] + p_block_slice_lengths_m0_n0_m1_n1[I2]); make_tuple(p_nd_idx[I2], p_nd_idx[I2] + p_block_slice_lengths_m0_n0_m1_n1[I2]);
constexpr auto nwave_range = make_tuple( constexpr auto nwave_range =
p_nd_idx[I3], p_nd_idx[I3] + p_block_slice_lengths_m0_n0_m1_n1[I3]); make_tuple(p_nd_idx[I3], p_nd_idx[I3] + p_block_slice_lengths_m0_n0_m1_n1[I3]);
#if 0 #if 0
if(hipThreadIdx_x % 64 == 0) if(hipThreadIdx_x % 64 == 0)
{ {
...@@ -1385,12 +1379,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -1385,12 +1379,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
p_thread_copy_subgroup.IsBelong(mwave_range, nwave_range)); p_thread_copy_subgroup.IsBelong(mwave_range, nwave_range));
} }
#endif #endif
if (p_thread_copy_subgroup.IsBelong(mwave_range, nwave_range)) if(p_thread_copy_subgroup.IsBelong(mwave_range, nwave_range))
{ {
p_thread_copy_vgpr_to_lds.Run( p_thread_copy_vgpr_to_lds.Run(
p_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4, p_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple( make_tuple(p_nd_idx[I0], p_nd_idx[I1], I0, I0, I0, I0, I0, I0),
p_nd_idx[I0], p_nd_idx[I1], I0, I0, I0, I0, I0, I0),
acc_thread_buf, acc_thread_buf,
p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
p_block_buf); p_block_buf);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment