Commit 8dad40d0 authored by Anthony Chang's avatar Anthony Chang
Browse files

finish dS; not yet validated

parent ee6c4ff7
......@@ -334,6 +334,58 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K();
static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K();
// Build a space-filling-curve iterator that walks every element of the
// per-thread C buffer, one scalar per step, in plain dimension order.
__host__ __device__ static constexpr auto MakeCThreadTileIterator()
{
    // Tile lengths of the per-thread C buffer; which 8D layout applies
    // depends on TransposeC.
    constexpr auto tile_lengths = conditional_expr<TransposeC>(
        GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths(),
        GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths());

    constexpr auto num_dims = tile_lengths.Size();

    // Visit dimensions 0..num_dims-1 in order, one element per access.
    using DimAccessOrder   = typename arithmetic_sequence_gen<0, num_dims, 1>::type;
    using ScalarsPerAccess = typename uniform_sequence_gen<num_dims, 1>::type;

    // Last argument disables snake-curve traversal.
    return SpaceFillingCurve<decltype(tile_lengths), DimAccessOrder, ScalarsPerAccess, false>{};
}
// Build a tensor adaptor whose top index is the 8D per-thread C-buffer
// coordinate and whose bottom index is the flat 2D (M, N) coordinate of the
// same element (via CalculateBottomIndex). The M- and N-sub-dimensions are
// interleaved in the 8D layout, so each unmerge picks its positions out of
// the top index explicitly.
__host__ __device__ static constexpr auto MakeCThreadIndexAdaptor8DTo2D()
{
    if constexpr(TransposeC)
    {
        // Thread buffer laid out as (M0, N0, M1, N1, M2, N2, N3, N4).
        constexpr auto desc = GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();

        constexpr auto m0_len = desc.GetLength(Number<0>{});
        constexpr auto n0_len = desc.GetLength(Number<1>{});
        constexpr auto m1_len = desc.GetLength(Number<2>{});
        constexpr auto n1_len = desc.GetLength(Number<3>{});
        constexpr auto m2_len = desc.GetLength(Number<4>{});
        constexpr auto n2_len = desc.GetLength(Number<5>{});
        constexpr auto n3_len = desc.GetLength(Number<6>{});
        constexpr auto n4_len = desc.GetLength(Number<7>{});

        // M <- (M0, M1, M2)         at top positions 0, 2, 4
        // N <- (N0, N1, N2, N3, N4) at top positions 1, 3, 5, 6, 7
        return make_single_stage_tensor_adaptor(
            make_tuple(
                make_unmerge_transform(make_tuple(m0_len, m1_len, m2_len)),
                make_unmerge_transform(
                    make_tuple(n0_len, n1_len, n2_len, n3_len, n4_len))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{}));
    }
    else
    {
        // Thread buffer laid out as (M0, N0, M1, N1, M2, M3, M4, N2).
        constexpr auto desc = GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

        constexpr auto m0_len = desc.GetLength(Number<0>{});
        constexpr auto n0_len = desc.GetLength(Number<1>{});
        constexpr auto m1_len = desc.GetLength(Number<2>{});
        constexpr auto n1_len = desc.GetLength(Number<3>{});
        constexpr auto m2_len = desc.GetLength(Number<4>{});
        constexpr auto m3_len = desc.GetLength(Number<5>{});
        constexpr auto m4_len = desc.GetLength(Number<6>{});
        constexpr auto n2_len = desc.GetLength(Number<7>{});

        // M <- (M0, M1, M2, M3, M4) at top positions 0, 2, 4, 5, 6
        // N <- (N0, N1, N2)         at top positions 1, 3, 7
        return make_single_stage_tensor_adaptor(
            make_tuple(
                make_unmerge_transform(
                    make_tuple(m0_len, m1_len, m2_len, m3_len, m4_len)),
                make_unmerge_transform(make_tuple(n0_len, n1_len, n2_len))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
    }
}
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
......@@ -936,6 +988,58 @@ struct BlockwiseGemmXdlops_v2
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k;
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k;
// Build a space-filling-curve iterator that walks every element of the
// per-thread C buffer, one scalar per step, in plain dimension order.
__host__ __device__ static constexpr auto MakeCThreadTileIterator()
{
    // Tile lengths of the per-thread C buffer; which 8D layout applies
    // depends on TransposeC.
    constexpr auto tile_lengths = conditional_expr<TransposeC>(
        GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths(),
        GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths());

    constexpr auto num_dims = tile_lengths.Size();

    // Visit dimensions 0..num_dims-1 in order, one element per access.
    using DimAccessOrder   = typename arithmetic_sequence_gen<0, num_dims, 1>::type;
    using ScalarsPerAccess = typename uniform_sequence_gen<num_dims, 1>::type;

    // Last argument disables snake-curve traversal.
    return SpaceFillingCurve<decltype(tile_lengths), DimAccessOrder, ScalarsPerAccess, false>{};
}
// Build a tensor adaptor whose top index is the 8D per-thread C-buffer
// coordinate and whose bottom index is the flat 2D (M, N) coordinate of the
// same element (via CalculateBottomIndex). The M- and N-sub-dimensions are
// interleaved in the 8D layout, so each unmerge picks its positions out of
// the top index explicitly.
__host__ __device__ static constexpr auto MakeCThreadIndexAdaptor8DTo2D()
{
    if constexpr(TransposeC)
    {
        // Thread buffer laid out as (M0, N0, M1, N1, M2, N2, N3, N4).
        constexpr auto desc = GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();

        constexpr auto m0_len = desc.GetLength(Number<0>{});
        constexpr auto n0_len = desc.GetLength(Number<1>{});
        constexpr auto m1_len = desc.GetLength(Number<2>{});
        constexpr auto n1_len = desc.GetLength(Number<3>{});
        constexpr auto m2_len = desc.GetLength(Number<4>{});
        constexpr auto n2_len = desc.GetLength(Number<5>{});
        constexpr auto n3_len = desc.GetLength(Number<6>{});
        constexpr auto n4_len = desc.GetLength(Number<7>{});

        // M <- (M0, M1, M2)         at top positions 0, 2, 4
        // N <- (N0, N1, N2, N3, N4) at top positions 1, 3, 5, 6, 7
        return make_single_stage_tensor_adaptor(
            make_tuple(
                make_unmerge_transform(make_tuple(m0_len, m1_len, m2_len)),
                make_unmerge_transform(
                    make_tuple(n0_len, n1_len, n2_len, n3_len, n4_len))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{}));
    }
    else
    {
        // Thread buffer laid out as (M0, N0, M1, N1, M2, M3, M4, N2).
        constexpr auto desc = GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

        constexpr auto m0_len = desc.GetLength(Number<0>{});
        constexpr auto n0_len = desc.GetLength(Number<1>{});
        constexpr auto m1_len = desc.GetLength(Number<2>{});
        constexpr auto n1_len = desc.GetLength(Number<3>{});
        constexpr auto m2_len = desc.GetLength(Number<4>{});
        constexpr auto m3_len = desc.GetLength(Number<5>{});
        constexpr auto m4_len = desc.GetLength(Number<6>{});
        constexpr auto n2_len = desc.GetLength(Number<7>{});

        // M <- (M0, M1, M2, M3, M4) at top positions 0, 2, 4, 5, 6
        // N <- (N0, N1, N2)         at top positions 1, 3, 7
        return make_single_stage_tensor_adaptor(
            make_tuple(
                make_unmerge_transform(
                    make_tuple(m0_len, m1_len, m2_len, m3_len, m4_len)),
                make_unmerge_transform(make_tuple(n0_len, n1_len, n2_len))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
    }
}
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
......
......@@ -1346,7 +1346,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// per-thread LSE data and y_dot_ygrad is
// tiled the same way
// TODO ANT: dP Gemm can reuse first blockwise gemm and pipeline
const auto ygrad_grid_desc_o0_m_o1 =
PGradGemmTile_M_N_O::MakeYGradGridDesc_O0_M_O1(ygrad_grid_desc_m0_o_m1);
const auto v_grid_desc_o0_n_o1 =
......@@ -1415,7 +1414,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
tensor_operation::element_wise::PassThrough{});
auto pgrad_blockwise_gemm = typename PGradGemmTile_M_N_O::BlockwiseGemm{};
auto pgrad_acc_thread_buf = pgrad_blockwise_gemm.GetCThreadBuffer();
auto pgrad_thread_buf = pgrad_blockwise_gemm.GetCThreadBuffer();
const auto pgrad_gemm_tile_ygrad_block_reset_copy_step =
make_multi_index(-ygrad_grid_desc_o0_m_o1.GetLength(I0), 0, 0);
const auto pgrad_gemm_tile_v_block_reset_copy_step =
......@@ -1762,7 +1761,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
vgrad_grid_buf);
// gemm dP
// assume size K == size O so has main block loop
// assume size K == size O so HasMainKBlockLoop is the same
block_sync_lds();
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(
ygrad_grid_desc_o0_m_o1,
......@@ -1778,7 +1777,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
b_block_buf, // reuse
b_block_slice_copy_step, // reuse
pgrad_blockwise_gemm,
pgrad_acc_thread_buf,
pgrad_thread_buf,
num_o_block_main_loop);
#if 0
if (hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
......@@ -1786,12 +1785,43 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
printf("j loop idx %d, tid %zd, dP[0:3] = %f, %f, %f, %f\n",
gemm1_k_block_outer_index,
hipThreadIdx_x,
pgrad_acc_thread_buf[I0],
pgrad_acc_thread_buf[I1],
pgrad_acc_thread_buf[I2],
pgrad_acc_thread_buf[I3]);
pgrad_thread_buf[I0],
pgrad_thread_buf[I1],
pgrad_thread_buf[I2],
pgrad_thread_buf[I3]);
}
#endif
// calculate dS from dP
auto& sgrad_thread_buf = pgrad_thread_buf;
constexpr auto pgrad_thread_tile_iterator =
pgrad_blockwise_gemm.MakeCThreadTileIterator();
constexpr auto pgrad_thread_idx_to_m_n_adaptor =
pgrad_blockwise_gemm.MakeCThreadIndexAdaptor8DTo2D();
static_for<0, pgrad_thread_tile_iterator.GetNumOfAccess(), 1>{}([&](auto i) {
constexpr auto pgrad_thread_idx = pgrad_thread_tile_iterator.GetIndex(i);
constexpr auto m =
pgrad_thread_idx_to_m_n_adaptor.CalculateBottomIndex(pgrad_thread_idx)[I0];
constexpr auto n =
pgrad_thread_idx_to_m_n_adaptor.CalculateBottomIndex(pgrad_thread_idx)[I1];
// dS and P has same thread buf layout
sgrad_thread_buf(i) =
acc_thread_buf[i] * (pgrad_thread_buf[i] * y_dot_ygrad_thread_buf[Number<m>{}]);
});
#if 0
if (hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
{
printf("j loop idx %d, tid %zd, dS[0:3] = %f, %f, %f, %f\n",
gemm1_k_block_outer_index,
hipThreadIdx_x,
sgrad_thread_buf[I0],
sgrad_thread_buf[I1],
sgrad_thread_buf[I2],
sgrad_thread_buf[I3]);
}
#endif
// move slice window
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_ak0_m_ak1,
a_block_reset_copy_step); // rewind K
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment