Commit b637c77d authored by Anthony Chang
Browse files

format

parent 8dad40d0
...@@ -340,28 +340,28 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -340,28 +340,28 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths(), GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths(),
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths()); GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths());
return SpaceFillingCurve< return SpaceFillingCurve<
decltype(c_thread_lengths), decltype(c_thread_lengths),
typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type, typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type, typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
false>{}; // SnakeCurved false>{}; // SnakeCurved
} }
__host__ __device__ static constexpr auto MakeCThreadIndexAdaptor8DTo2D() __host__ __device__ static constexpr auto MakeCThreadIndexAdaptor8DTo2D()
{ {
if constexpr (TransposeC) if constexpr(TransposeC)
{ {
constexpr auto c_thread_desc = GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); constexpr auto c_thread_desc = GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto m0 = c_thread_desc.GetLength(Number<0>{}); constexpr auto m0 = c_thread_desc.GetLength(Number<0>{});
constexpr auto n0 = c_thread_desc.GetLength(Number<1>{}); constexpr auto n0 = c_thread_desc.GetLength(Number<1>{});
constexpr auto m1 = c_thread_desc.GetLength(Number<2>{}); constexpr auto m1 = c_thread_desc.GetLength(Number<2>{});
constexpr auto n1 = c_thread_desc.GetLength(Number<3>{}); constexpr auto n1 = c_thread_desc.GetLength(Number<3>{});
constexpr auto m2 = c_thread_desc.GetLength(Number<4>{}); constexpr auto m2 = c_thread_desc.GetLength(Number<4>{});
constexpr auto n2 = c_thread_desc.GetLength(Number<5>{}); constexpr auto n2 = c_thread_desc.GetLength(Number<5>{});
constexpr auto n3 = c_thread_desc.GetLength(Number<6>{}); constexpr auto n3 = c_thread_desc.GetLength(Number<6>{});
constexpr auto n4 = c_thread_desc.GetLength(Number<7>{}); constexpr auto n4 = c_thread_desc.GetLength(Number<7>{});
constexpr auto thread_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor( constexpr auto thread_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(m0, m1, m2)), make_tuple(make_unmerge_transform(make_tuple(m0, m1, m2)),
make_unmerge_transform(make_tuple(n0, n1, n2, n3, n4))), make_unmerge_transform(make_tuple(n0, n1, n2, n3, n4))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{})); make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{}));
return thread_idx_to_m_n_adaptor; return thread_idx_to_m_n_adaptor;
...@@ -369,17 +369,17 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -369,17 +369,17 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
else else
{ {
constexpr auto c_thread_desc = GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); constexpr auto c_thread_desc = GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto m0 = c_thread_desc.GetLength(Number<0>{}); constexpr auto m0 = c_thread_desc.GetLength(Number<0>{});
constexpr auto n0 = c_thread_desc.GetLength(Number<1>{}); constexpr auto n0 = c_thread_desc.GetLength(Number<1>{});
constexpr auto m1 = c_thread_desc.GetLength(Number<2>{}); constexpr auto m1 = c_thread_desc.GetLength(Number<2>{});
constexpr auto n1 = c_thread_desc.GetLength(Number<3>{}); constexpr auto n1 = c_thread_desc.GetLength(Number<3>{});
constexpr auto m2 = c_thread_desc.GetLength(Number<4>{}); constexpr auto m2 = c_thread_desc.GetLength(Number<4>{});
constexpr auto m3 = c_thread_desc.GetLength(Number<5>{}); constexpr auto m3 = c_thread_desc.GetLength(Number<5>{});
constexpr auto m4 = c_thread_desc.GetLength(Number<6>{}); constexpr auto m4 = c_thread_desc.GetLength(Number<6>{});
constexpr auto n2 = c_thread_desc.GetLength(Number<7>{}); constexpr auto n2 = c_thread_desc.GetLength(Number<7>{});
constexpr auto thread_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor( constexpr auto thread_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(m0, m1, m2, m3, m4)), make_tuple(make_unmerge_transform(make_tuple(m0, m1, m2, m3, m4)),
make_unmerge_transform(make_tuple(n0, n1, n2))), make_unmerge_transform(make_tuple(n0, n1, n2))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{})); make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
return thread_idx_to_m_n_adaptor; return thread_idx_to_m_n_adaptor;
...@@ -1002,20 +1002,20 @@ struct BlockwiseGemmXdlops_v2 ...@@ -1002,20 +1002,20 @@ struct BlockwiseGemmXdlops_v2
__host__ __device__ static constexpr auto MakeCThreadIndexAdaptor8DTo2D() __host__ __device__ static constexpr auto MakeCThreadIndexAdaptor8DTo2D()
{ {
if constexpr (TransposeC) if constexpr(TransposeC)
{ {
constexpr auto c_thread_desc = GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); constexpr auto c_thread_desc = GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto m0 = c_thread_desc.GetLength(Number<0>{}); constexpr auto m0 = c_thread_desc.GetLength(Number<0>{});
constexpr auto n0 = c_thread_desc.GetLength(Number<1>{}); constexpr auto n0 = c_thread_desc.GetLength(Number<1>{});
constexpr auto m1 = c_thread_desc.GetLength(Number<2>{}); constexpr auto m1 = c_thread_desc.GetLength(Number<2>{});
constexpr auto n1 = c_thread_desc.GetLength(Number<3>{}); constexpr auto n1 = c_thread_desc.GetLength(Number<3>{});
constexpr auto m2 = c_thread_desc.GetLength(Number<4>{}); constexpr auto m2 = c_thread_desc.GetLength(Number<4>{});
constexpr auto n2 = c_thread_desc.GetLength(Number<5>{}); constexpr auto n2 = c_thread_desc.GetLength(Number<5>{});
constexpr auto n3 = c_thread_desc.GetLength(Number<6>{}); constexpr auto n3 = c_thread_desc.GetLength(Number<6>{});
constexpr auto n4 = c_thread_desc.GetLength(Number<7>{}); constexpr auto n4 = c_thread_desc.GetLength(Number<7>{});
constexpr auto thread_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor( constexpr auto thread_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(m0, m1, m2)), make_tuple(make_unmerge_transform(make_tuple(m0, m1, m2)),
make_unmerge_transform(make_tuple(n0, n1, n2, n3, n4))), make_unmerge_transform(make_tuple(n0, n1, n2, n3, n4))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{})); make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{}));
return thread_idx_to_m_n_adaptor; return thread_idx_to_m_n_adaptor;
...@@ -1023,17 +1023,17 @@ struct BlockwiseGemmXdlops_v2 ...@@ -1023,17 +1023,17 @@ struct BlockwiseGemmXdlops_v2
else else
{ {
constexpr auto c_thread_desc = GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); constexpr auto c_thread_desc = GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto m0 = c_thread_desc.GetLength(Number<0>{}); constexpr auto m0 = c_thread_desc.GetLength(Number<0>{});
constexpr auto n0 = c_thread_desc.GetLength(Number<1>{}); constexpr auto n0 = c_thread_desc.GetLength(Number<1>{});
constexpr auto m1 = c_thread_desc.GetLength(Number<2>{}); constexpr auto m1 = c_thread_desc.GetLength(Number<2>{});
constexpr auto n1 = c_thread_desc.GetLength(Number<3>{}); constexpr auto n1 = c_thread_desc.GetLength(Number<3>{});
constexpr auto m2 = c_thread_desc.GetLength(Number<4>{}); constexpr auto m2 = c_thread_desc.GetLength(Number<4>{});
constexpr auto m3 = c_thread_desc.GetLength(Number<5>{}); constexpr auto m3 = c_thread_desc.GetLength(Number<5>{});
constexpr auto m4 = c_thread_desc.GetLength(Number<6>{}); constexpr auto m4 = c_thread_desc.GetLength(Number<6>{});
constexpr auto n2 = c_thread_desc.GetLength(Number<7>{}); constexpr auto n2 = c_thread_desc.GetLength(Number<7>{});
constexpr auto thread_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor( constexpr auto thread_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(m0, m1, m2, m3, m4)), make_tuple(make_unmerge_transform(make_tuple(m0, m1, m2, m3, m4)),
make_unmerge_transform(make_tuple(n0, n1, n2))), make_unmerge_transform(make_tuple(n0, n1, n2))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{})); make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
return thread_idx_to_m_n_adaptor; return thread_idx_to_m_n_adaptor;
......
...@@ -487,7 +487,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -487,7 +487,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
const auto M1 = ygrad_grid_desc_m0_o_m1.GetLength(I2); const auto M1 = ygrad_grid_desc_m0_o_m1.GetLength(I2);
constexpr auto Y_O1 = AK1; constexpr auto Y_O1 = AK1;
const auto Y_O0 = O / Y_O1; const auto Y_O0 = O / Y_O1;
const auto ygrad_grid_desc_o0_m_o1 = transform_tensor_descriptor( const auto ygrad_grid_desc_o0_m_o1 = transform_tensor_descriptor(
ygrad_grid_desc_m0_o_m1, ygrad_grid_desc_m0_o_m1,
...@@ -508,7 +508,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -508,7 +508,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
const auto N1 = v_grid_desc_n0_o_n1.GetLength(I2); const auto N1 = v_grid_desc_n0_o_n1.GetLength(I2);
constexpr auto V_O1 = BK1; constexpr auto V_O1 = BK1;
const auto V_O0 = O / V_O1; const auto V_O0 = O / V_O1;
const auto v_grid_desc_o0_n_o1 = transform_tensor_descriptor( const auto v_grid_desc_o0_n_o1 = transform_tensor_descriptor(
v_grid_desc_n0_o_n1, v_grid_desc_n0_o_n1,
...@@ -1414,7 +1414,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -1414,7 +1414,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
tensor_operation::element_wise::PassThrough{}); tensor_operation::element_wise::PassThrough{});
auto pgrad_blockwise_gemm = typename PGradGemmTile_M_N_O::BlockwiseGemm{}; auto pgrad_blockwise_gemm = typename PGradGemmTile_M_N_O::BlockwiseGemm{};
auto pgrad_thread_buf = pgrad_blockwise_gemm.GetCThreadBuffer(); auto pgrad_thread_buf = pgrad_blockwise_gemm.GetCThreadBuffer();
const auto pgrad_gemm_tile_ygrad_block_reset_copy_step = const auto pgrad_gemm_tile_ygrad_block_reset_copy_step =
make_multi_index(-ygrad_grid_desc_o0_m_o1.GetLength(I0), 0, 0); make_multi_index(-ygrad_grid_desc_o0_m_o1.GetLength(I0), 0, 0);
const auto pgrad_gemm_tile_v_block_reset_copy_step = const auto pgrad_gemm_tile_v_block_reset_copy_step =
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment