"tests/pipelines/vscode:/vscode.git/clone" did not exist on "9920c333c69d372911a4549cde7cb7cc12cd4dc8"
Commit f1b2e521 authored by Anthony Chang

format

parent 4ae9919e
@@ -15,7 +15,7 @@ Outputs:
 */
 #pragma clang diagnostic ignored "-Wunused-variable" // TODO ANT: remove
 #define PRINT_HOST 0
...
@@ -361,10 +361,12 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                                      v_gs_ns_os_lengths_vec, v_gs_ns_os_strides_vec)
                .second;
-        // LogRangeAsType<float>(std::cout << "v_gs_os_ns_lengths_vec: ", v_gs_os_ns_lengths_vec, ",") << std::endl;
-        // LogRangeAsType<float>(std::cout << "v_gs_os_ns_strides_vec: ", v_gs_os_ns_strides_vec, ",") << std::endl;
-        // LogRangeAsType<float>(std::cout << "v_gs_ns_os_lengths_vec: ", v_gs_ns_os_lengths_vec, ",") << std::endl;
-        // LogRangeAsType<float>(std::cout << "v_gs_ns_os_strides_vec: ", v_gs_ns_os_strides_vec, ",") << std::endl;
+        // LogRangeAsType<float>(std::cout << "v_gs_os_ns_lengths_vec: ", v_gs_os_ns_lengths_vec,
+        // ",") << std::endl; LogRangeAsType<float>(std::cout << "v_gs_os_ns_strides_vec: ",
+        // v_gs_os_ns_strides_vec, ",") << std::endl; LogRangeAsType<float>(std::cout <<
+        // "v_gs_ns_os_lengths_vec: ", v_gs_ns_os_lengths_vec, ",") << std::endl;
+        // LogRangeAsType<float>(std::cout << "v_gs_ns_os_strides_vec: ", v_gs_ns_os_strides_vec,
+        // ",") << std::endl;
         return PadTensorDescriptor(vgrad_desc_nraw_oraw,
                                    make_tuple(NPerBlock, Gemm1NPerBlock),
                                    Sequence<padder.PadN, padder.PadO>{});
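Note on the hunk above: PadTensorDescriptor rounds the raw N and O extents up to whole multiples of NPerBlock and Gemm1NPerBlock so the grid can be tiled by full blocks. A minimal standalone sketch of that rounding arithmetic, assuming only the usual ceiling-division formula (round_up is a hypothetical helper, not the CK API):

```cpp
#include <cstdio>

// Round len up to the next multiple of tile (ceiling division, then scale back).
constexpr int round_up(int len, int tile) { return (len + tile - 1) / tile * tile; }

int main()
{
    constexpr int NPerBlock = 128, Gemm1NPerBlock = 64;
    int n_raw = 1000, o_raw = 80; // illustrative raw (unpadded) lengths
    std::printf("padded N = %d, padded O = %d\n",
                round_up(n_raw, NPerBlock),       // -> 1024
                round_up(o_raw, Gemm1NPerBlock)); // -> 128
    return 0;
}
```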
@@ -685,7 +687,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                  c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
              batch_count_{c_grid_desc_g_m_n_.GetLength(I0)},
              compute_base_ptr_of_batch_{
-                 a_grid_desc_g_m_k_, b_grid_desc_g_n_k_, b1_grid_desc_g_n_k_, c_grid_desc_g_m_n_, type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())}
+                 a_grid_desc_g_m_k_,
+                 b_grid_desc_g_n_k_,
+                 b1_grid_desc_g_n_k_,
+                 c_grid_desc_g_m_n_,
+                 type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())}
        {
            // TODO ANT: implement bias addition
            ignore = p_acc0_biases;
@@ -726,9 +732,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                      << c_grid_desc_g_m_n_.GetLength(I1) << ", "
                      << c_grid_desc_g_m_n_.GetLength(I2) << '\n';
            // c_grid_desc_g_m_n_.Print();
-            std::cout << "vgrad_grid_desc_n_o_: " << vgrad_grid_desc_n_o_.GetLength(I0) << ", " << vgrad_grid_desc_n_o_.GetLength(I1) << '\n';
-            std::cout << "ygrad_grid_desc_m0_o_m1_: " << ygrad_grid_desc_m0_o_m1_.GetLength(I0) << ", "
-                      << ygrad_grid_desc_m0_o_m1_.GetLength(I1) << ", "
+            std::cout << "vgrad_grid_desc_n_o_: " << vgrad_grid_desc_n_o_.GetLength(I0) << ", "
+                      << vgrad_grid_desc_n_o_.GetLength(I1) << '\n';
+            std::cout << "ygrad_grid_desc_m0_o_m1_: " << ygrad_grid_desc_m0_o_m1_.GetLength(I0)
+                      << ", " << ygrad_grid_desc_m0_o_m1_.GetLength(I1) << ", "
                      << ygrad_grid_desc_m0_o_m1_.GetLength(I2) << '\n';
        }
...
@@ -415,13 +415,14 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
    __host__ __device__ static constexpr auto
    MakeLSEGridDescriptor_MBlock_MRepeat_NWave_MPerXdl(const LSEGridDesc_M& lse_grid_desc_m)
    {
        const index_t M      = lse_grid_desc_m.GetLength(I0);
        const index_t MBlock = M / MPerBlock;
        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);

        const auto lse_grid_desc_mblock_mrepeat_mwave_mperxdl = transform_tensor_descriptor(
            lse_grid_desc_m,
-            make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MXdlPerWave>{}, MWave, Number<MPerXdl>{}))),
+            make_tuple(make_unmerge_transform(
+                make_tuple(MBlock, Number<MXdlPerWave>{}, MWave, Number<MPerXdl>{}))),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0, 1, 2, 3>{}));
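For orientation: the unmerge transform in this hunk splits the flat M coordinate of the LSE vector into (MBlock, MXdlPerWave, MWave, MPerXdl) digits, most significant first, matching the order in the make_tuple. A plain C++ sketch of that index arithmetic, assuming the standard mixed-radix decomposition rather than the CK descriptor machinery:

```cpp
#include <array>
#include <cstdio>

// Decompose a flat m into (MBlock, MXdlPerWave, MWave, MPerXdl) digits.
std::array<int, 4> unmerge_m(int m, int mxdl_per_wave, int mwave, int mper_xdl)
{
    std::array<int, 4> idx{};
    idx[3] = m % mper_xdl;      m /= mper_xdl;      // least-significant digit
    idx[2] = m % mwave;         m /= mwave;
    idx[1] = m % mxdl_per_wave; m /= mxdl_per_wave;
    idx[0] = m;                                     // whatever remains is MBlock
    return idx;
}

int main()
{
    // e.g. MPerBlock = 128 = MXdlPerWave(4) * MWave(1) * MPerXdl(32); m = 300
    auto idx = unmerge_m(300, 4, 1, 32);
    std::printf("(MBlock, MXdlPerWave, MWave, MPerXdl) = (%d, %d, %d, %d)\n",
                idx[0], idx[1], idx[2], idx[3]); // -> (2, 1, 0, 12), since 300 = ((2*4+1)*1+0)*32+12
    return 0;
}
```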
@@ -469,10 +470,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
        static constexpr auto ygrad_block_space_size_aligned = math::integer_least_multiple(
            ygrad_block_desc_m0_o_m1.GetElementSpaceSize(), max_lds_align);

        static constexpr auto a_block_space_offset     = 0;
        static constexpr auto b_block_space_offset     = a_block_space_size_aligned.value;
        static constexpr auto b1_block_space_offset    = 0;
        static constexpr auto p_block_space_offset     = 0;
        static constexpr auto ygrad_block_space_offset = p_block_space_size_aligned.value;

        // LDS allocation for reduction
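A hedged sketch of the LDS partitioning pattern in this hunk: each tile's element count is rounded up to the alignment, offsets are laid out back to back, and b1/p restart at offset 0, presumably reusing the a/b region because those tiles are live in a later phase of the kernel. The constants below are illustrative, not the kernel's real sizes:

```cpp
#include <cstdio>

// Round size up to a multiple of align (mirrors math::integer_least_multiple).
constexpr int integer_least_multiple(int size, int align)
{
    return (size + align - 1) / align * align;
}

int main()
{
    constexpr int max_lds_align = 8;
    constexpr int a_size = integer_least_multiple(1100, max_lds_align); // -> 1104
    constexpr int b_size = integer_least_multiple(900, max_lds_align);  // -> 904
    // Offsets mirror the diff: b follows a; b1 and p start back at 0,
    // overlapping the a/b region (assumed safe only if lifetimes are disjoint).
    constexpr int a_off  = 0;
    constexpr int b_off  = a_size;
    constexpr int b1_off = 0;
    constexpr int p_off  = 0;
    std::printf("a@%d (%d) b@%d (%d) b1@%d p@%d\n",
                a_off, a_size, b_off, b_size, b1_off, p_off);
    return 0;
}
```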
@@ -1015,18 +1016,17 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                1, // DstScalarPerVector
                InMemoryDataOperationEnum::Set,
                1, // DstScalarStrideInVector
-                true>{
-                p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
-                make_multi_index(
-                    p_thread_origin_nd_idx_on_block[I0],
-                    p_thread_origin_nd_idx_on_block[I1],
-                    p_thread_origin_nd_idx_on_block[I2] % p_block_slice_lengths_m0_n0_m1_n1[I2],
-                    p_thread_origin_nd_idx_on_block[I3] % p_block_slice_lengths_m0_n0_m1_n1[I3],
-                    p_thread_origin_nd_idx_on_block[I4],
-                    p_thread_origin_nd_idx_on_block[I5],
-                    p_thread_origin_nd_idx_on_block[I6],
-                    p_thread_origin_nd_idx_on_block[I7]),
-                tensor_operation::element_wise::PassThrough{}};
+                true>{p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+                      make_multi_index(
+                          p_thread_origin_nd_idx_on_block[I0],
+                          p_thread_origin_nd_idx_on_block[I1],
+                          p_thread_origin_nd_idx_on_block[I2] % p_block_slice_lengths_m0_n0_m1_n1[I2],
+                          p_thread_origin_nd_idx_on_block[I3] % p_block_slice_lengths_m0_n0_m1_n1[I3],
+                          p_thread_origin_nd_idx_on_block[I4],
+                          p_thread_origin_nd_idx_on_block[I5],
+                          p_thread_origin_nd_idx_on_block[I6],
+                          p_thread_origin_nd_idx_on_block[I7]),
+                      tensor_operation::element_wise::PassThrough{}};

            // Sequence<p_block_slice_lengths_m0_n0_m1_n1[I0],
            //          p_block_slice_lengths_m0_n0_m1_n1[I1],
@@ -1087,18 +1087,18 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                static_cast<DataType*>(p_shared) + SharedMemTrait::ygrad_block_space_offset,
                ygrad_block_desc_m0_o_m1.GetElementSpaceSize());

-            auto vgrad_blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
-                BlockSize,
-                DataType,
-                FloatGemmAcc,
-                decltype(p_block_desc_m0_n_m1),
-                decltype(ygrad_block_desc_m0_o_m1),
-                MPerXdl,
-                NPerXdl,
-                VGradGemmTile_N_O_M::GemmNRepeat,
-                VGradGemmTile_N_O_M::GemmORepeat,
-                VGradGemmTile_N_O_M::GemmMPack,
-                true>{}; // TranspossC
+            auto vgrad_blockwise_gemm =
+                BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
+                                                                    DataType,
+                                                                    FloatGemmAcc,
+                                                                    decltype(p_block_desc_m0_n_m1),
+                                                                    decltype(ygrad_block_desc_m0_o_m1),
+                                                                    MPerXdl,
+                                                                    NPerXdl,
+                                                                    VGradGemmTile_N_O_M::GemmNRepeat,
+                                                                    VGradGemmTile_N_O_M::GemmORepeat,
+                                                                    VGradGemmTile_N_O_M::GemmMPack,
+                                                                    true>{}; // TranspossC

            auto vgrad_acc_thread_buf = vgrad_blockwise_gemm.GetCThreadBuffer();
@@ -1107,12 +1107,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            const auto vgrad_grid_desc_n0_o0_n1_o1_n2_o2 = transform_tensor_descriptor(
                vgrad_grid_desc_n_o,
                make_tuple(
-                    make_unmerge_transform(make_tuple(I1,
-                                                      VGradGemmTile_N_O_M::GemmNWave,
-                                                      MPerXdl)),
-                    make_unmerge_transform(make_tuple(I1,
-                                                      VGradGemmTile_N_O_M::GemmOWave,
-                                                      NPerXdl))),
+                    make_unmerge_transform(make_tuple(I1, VGradGemmTile_N_O_M::GemmNWave, MPerXdl)),
+                    make_unmerge_transform(make_tuple(I1, VGradGemmTile_N_O_M::GemmOWave, NPerXdl))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
@@ -1142,12 +1138,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            const index_t o_thread_data_idx_on_grid =
                vgrad_thread_mtx_on_block_n_o[I1] + gemm1_n_block_data_idx_on_grid;

-            const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor =
-                make_single_stage_tensor_adaptor(
-                    make_tuple(make_merge_transform(
-                        make_tuple(VGrad_N0, VGrad_N1, VGrad_N2))),
-                    make_tuple(Sequence<0, 1, 2>{}),
-                    make_tuple(Sequence<0>{}));
+            const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
+                make_tuple(make_merge_transform(make_tuple(VGrad_N0, VGrad_N1, VGrad_N2))),
+                make_tuple(Sequence<0, 1, 2>{}),
+                make_tuple(Sequence<0>{}));

            const auto n_thread_data_nd_idx_on_grid =
                n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
@@ -1171,9 +1165,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                decltype(vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4),
                tensor_operation::element_wise::PassThrough, // CElementwiseOperation
                decltype(vgrad_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLengths()), // SliceLengths
                Sequence<0, 1, 2, 3, 4, 5, 6, 7>,     // AccessOrder
                7,                                    // VectorDim
                2,                                    // ScalarPerVector
                InMemoryDataOperationEnum::AtomicAdd, // GlobalMemoryDataOperation
                1,                                    // DstScalarStrideInVector
                true>(vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
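The InMemoryDataOperationEnum::AtomicAdd in this hunk is the interesting setting: partial VGrad tiles produced by different iterations (and, depending on the grid mapping, different workgroups) accumulate into the same global N x O region, so a plain Set store would race. A minimal plain-HIP sketch of that accumulation pattern, not the CK threadwise copy class (kernel and buffer names are illustrative):

```cpp
#include <hip/hip_runtime.h>
#include <cstdio>
#include <vector>

// Each block owns one partial tile and adds it into the shared output buffer;
// atomicAdd makes the accumulation correct regardless of block scheduling.
__global__ void accumulate_partials(float* dv, const float* partials, int tile_elems)
{
    const float* my_partial = partials + blockIdx.x * tile_elems;
    for(int i = threadIdx.x; i < tile_elems; i += blockDim.x)
        atomicAdd(&dv[i], my_partial[i]); // many blocks target the same dv[i]
}

int main()
{
    const int tiles = 4, tile_elems = 256;
    std::vector<float> h_partials(tiles * tile_elems, 1.0f);
    float *d_dv, *d_partials;
    hipMalloc(&d_dv, tile_elems * sizeof(float));
    hipMalloc(&d_partials, tiles * tile_elems * sizeof(float));
    hipMemset(d_dv, 0, tile_elems * sizeof(float));
    hipMemcpy(d_partials, h_partials.data(), tiles * tile_elems * sizeof(float),
              hipMemcpyHostToDevice);
    accumulate_partials<<<tiles, 64>>>(d_dv, d_partials, tile_elems);
    std::vector<float> h_dv(tile_elems);
    hipMemcpy(h_dv.data(), d_dv, tile_elems * sizeof(float), hipMemcpyDeviceToHost);
    std::printf("dv[0] = %f (expect %d)\n", h_dv[0], tiles); // 4 partial tiles summed
    hipFree(d_dv);
    hipFree(d_partials);
    return 0;
}
```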
@@ -1226,10 +1220,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            constexpr index_t num_vgrad_gemm_loop = MPerBlock / VGradGemmTile_N_O_M::Sum_M;

            lse_thread_copy_global_to_vgpr.Run(lse_grid_desc_mblock_mrepeat_mwave_mperxdl,
                                               lse_grid_buf,
                                               lse_thread_desc_mblock_mrepeat_mwave_mperxdl,
                                               make_tuple(I0, I0, I0, I0),
                                               lse_thread_buf);

            // gemm1 K loop
            index_t gemm1_k_block_outer_index = 0;
@@ -1366,10 +1360,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                    // load VGrad Gemm A
                    const auto p_nd_idx =
                        sfc_p_m0_n0_m1_n1_m2_n2.GetIndexTupleOfNumber(vgrad_gemm_loop_idx);
-                    constexpr auto mwave_range = make_tuple(
-                        p_nd_idx[I2], p_nd_idx[I2] + p_block_slice_lengths_m0_n0_m1_n1[I2]);
-                    constexpr auto nwave_range = make_tuple(
-                        p_nd_idx[I3], p_nd_idx[I3] + p_block_slice_lengths_m0_n0_m1_n1[I3]);
+                    constexpr auto mwave_range =
+                        make_tuple(p_nd_idx[I2], p_nd_idx[I2] + p_block_slice_lengths_m0_n0_m1_n1[I2]);
+                    constexpr auto nwave_range =
+                        make_tuple(p_nd_idx[I3], p_nd_idx[I3] + p_block_slice_lengths_m0_n0_m1_n1[I3]);
#if 0
                    if(hipThreadIdx_x % 64 == 0)
                    {
@@ -1385,12 +1379,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                               p_thread_copy_subgroup.IsBelong(mwave_range, nwave_range));
                    }
#endif
-                    if (p_thread_copy_subgroup.IsBelong(mwave_range, nwave_range))
+                    if(p_thread_copy_subgroup.IsBelong(mwave_range, nwave_range))
                    {
                        p_thread_copy_vgpr_to_lds.Run(
                            p_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
-                            make_tuple(
-                                p_nd_idx[I0], p_nd_idx[I1], I0, I0, I0, I0, I0, I0),
+                            make_tuple(p_nd_idx[I0], p_nd_idx[I1], I0, I0, I0, I0, I0, I0),
                            acc_thread_buf,
                            p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                            p_block_buf);
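The IsBelong guard above keeps the VGPR-to-LDS copy to only those waves whose (mwave, nwave) coordinates fall inside the tile selected by the space-filling curve; since the ranges are built as (begin, begin + length), they are presumably half-open. A hypothetical standalone version of the predicate, not the CK implementation:

```cpp
#include <cstdio>
#include <utility>

using Range = std::pair<int, int>; // assumed half-open: [begin, begin + length)

// Illustrative stand-in for p_thread_copy_subgroup.IsBelong(...): a wave joins
// the copy only when its wave coordinates lie inside both ranges.
bool is_belong(int mwave, int nwave, Range mrange, Range nrange)
{
    return mwave >= mrange.first && mwave < mrange.second &&
           nwave >= nrange.first && nwave < nrange.second;
}

int main()
{
    Range mwave_range{0, 2}, nwave_range{1, 3};
    std::printf("%d %d\n",
                is_belong(1, 1, mwave_range, nwave_range),  // 1: inside both ranges
                is_belong(1, 3, mwave_range, nwave_range)); // 0: past the nwave end
    return 0;
}
```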
...