"tests/vscode:/vscode.git/clone" did not exist on "5e704a2c71a7c2cc819b2311d9b6c35e6bfe6797"
Commit 337d6703 authored by fsx950223

merge updates

parents de43a6d8 27482328
...@@ -9,6 +9,8 @@ add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_train_xdl
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_train_xdl_bf16 batched_gemm_scale_softmax_gemm_permute_train_xdl_bf16.cpp)
add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_batched_multihead_attention_backward_fp16 batched_multihead_attention_backward_fp16.cpp)
add_example_executable(example_batched_multihead_attention_backward_pt1_fp16 batched_multihead_attention_backward_pt1_fp16.cpp)
add_custom_target(example_gemm_scale_softmax_gemm)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
......
...@@ -50,7 +50,8 @@ template <index_t BlockSize,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
index_t KPack,
bool TransposeC = false>
struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
{
static constexpr auto I0 = Number<0>{};
...@@ -72,7 +73,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack>{};
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, TransposeC>{};
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
...@@ -185,6 +186,21 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
"wrong!");
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, N, M0, M1, M2));
}
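// The TransposeC variants above emit the per-thread XDL result laid out as
// C_xdl' = B_xdl' * A_xdl' (the prime denoting a transpose), i.e. the same
// product with the M and N roles swapped in the output descriptor.
// A minimal standalone check of the underlying identity (A*B)' = B'*A',
// written in plain C++ without any CK types; the matrix sizes are
// illustration values, not taken from this commit.
#include <array>
#include <cassert>

int main()
{
    constexpr int M = 2, K = 3, N = 2;
    std::array<std::array<float, K>, M> a{{{1, 2, 3}, {4, 5, 6}}};
    std::array<std::array<float, N>, K> b{{{1, 0}, {0, 1}, {1, 1}}};

    // c[m][n] = sum_k a[m][k] * b[k][n]
    std::array<std::array<float, N>, M> c{};
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
                c[m][n] += a[m][k] * b[k][n];

    // ct[n][m] = sum_k b'[n][k] * a'[k][m] = sum_k b[k][n] * a[m][k]
    std::array<std::array<float, M>, N> ct{};
    for(int n = 0; n < N; ++n)
        for(int m = 0; m < M; ++m)
            for(int k = 0; k < K; ++k)
                ct[n][m] += b[k][n] * a[m][k];

    // the transposed product equals the transpose of the product
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            assert(ct[n][m] == c[m][n]);
    return 0;
}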
// XDL output supporting C_xdl = A_xdl * B_xdl
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
...@@ -211,6 +227,21 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
{
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2);
}
// XDL output supporting C_xdl = A_xdl * B_xdl
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
...@@ -303,6 +334,58 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K();
static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K();
__host__ __device__ static constexpr auto MakeCThreadTileIterator()
{
constexpr auto c_thread_lengths = conditional_expr<TransposeC>(
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths(),
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths());
return SpaceFillingCurve<
decltype(c_thread_lengths),
typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
false>{}; // SnakeCurved
}
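// MakeCThreadTileIterator() above builds a SpaceFillingCurve over the
// per-thread C tile lengths with unit scan steps and, per the trailing
// comment, the snake-curve option disabled, i.e. a plain lexicographic
// sweep. A standalone sketch of the two visit orders for a small 2-D tile
// (illustration only; the real curve runs over the 8-D lengths above):
#include <cstdio>

int main()
{
    constexpr int D0 = 2, D1 = 3;

    std::printf("lexicographic (SnakeCurved = false):\n");
    for(int i0 = 0; i0 < D0; ++i0)
        for(int i1 = 0; i1 < D1; ++i1)
            std::printf("(%d,%d) ", i0, i1);
    std::printf("\n");

    std::printf("snake (SnakeCurved = true):\n");
    for(int i0 = 0; i0 < D0; ++i0)
        for(int j = 0; j < D1; ++j)
        {
            // odd rows are swept backwards so consecutive visits stay adjacent
            int i1 = (i0 % 2 == 0) ? j : (D1 - 1 - j);
            std::printf("(%d,%d) ", i0, i1);
        }
    std::printf("\n");
    return 0;
}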
__host__ __device__ static constexpr auto MakeCThreadIndexAdaptor8DTo2D()
{
if constexpr(TransposeC)
{
constexpr auto c_thread_desc = GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto m0 = c_thread_desc.GetLength(Number<0>{});
constexpr auto n0 = c_thread_desc.GetLength(Number<1>{});
constexpr auto m1 = c_thread_desc.GetLength(Number<2>{});
constexpr auto n1 = c_thread_desc.GetLength(Number<3>{});
constexpr auto m2 = c_thread_desc.GetLength(Number<4>{});
constexpr auto n2 = c_thread_desc.GetLength(Number<5>{});
constexpr auto n3 = c_thread_desc.GetLength(Number<6>{});
constexpr auto n4 = c_thread_desc.GetLength(Number<7>{});
constexpr auto thread_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(m0, m1, m2)),
make_unmerge_transform(make_tuple(n0, n1, n2, n3, n4))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{}));
return thread_idx_to_m_n_adaptor;
}
else
{
constexpr auto c_thread_desc = GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto m0 = c_thread_desc.GetLength(Number<0>{});
constexpr auto n0 = c_thread_desc.GetLength(Number<1>{});
constexpr auto m1 = c_thread_desc.GetLength(Number<2>{});
constexpr auto n1 = c_thread_desc.GetLength(Number<3>{});
constexpr auto m2 = c_thread_desc.GetLength(Number<4>{});
constexpr auto m3 = c_thread_desc.GetLength(Number<5>{});
constexpr auto m4 = c_thread_desc.GetLength(Number<6>{});
constexpr auto n2 = c_thread_desc.GetLength(Number<7>{});
constexpr auto thread_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(m0, m1, m2, m3, m4)),
make_unmerge_transform(make_tuple(n0, n1, n2))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
return thread_idx_to_m_n_adaptor;
}
}
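// MakeCThreadIndexAdaptor8DTo2D() above ties the 8-D per-thread C coordinate
// to a 2-D (m, n) coordinate through two unmerge transforms; conceptually,
// each merged index is the row-major combination of its split components.
// A standalone model of that index arithmetic for the non-transposed case.
// The split lengths M1..M4, N1, N2 and the sample coordinate are illustration
// values only, not the lengths computed above.
#include <cstdio>

int main()
{
    constexpr int M1 = 2, M2 = 2, M3 = 2, M4 = 4;
    constexpr int N1 = 2, N2 = 4;

    // one 8-D coordinate in the order (m0, n0, m1, n1, m2, m3, m4, n2)
    int m0 = 1, n0 = 0, m1 = 1, n1 = 1, m2 = 0, m3 = 1, m4 = 3, n2 = 2;

    // merged M index: row-major over (m0, m1, m2, m3, m4)
    int m = (((m0 * M1 + m1) * M2 + m2) * M3 + m3) * M4 + m4;
    // merged N index: row-major over (n0, n1, n2)
    int n = (n0 * N1 + n1) * N2 + n2;

    std::printf("8-D coordinate maps to (m, n) = (%d, %d)\n", m, n);
    return 0;
}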
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
...@@ -776,6 +859,21 @@ struct BlockwiseGemmXdlops_v2
"wrong!");
}
__host__ __device__ BlockwiseGemmXdlops_v2(index_t switch_flag,
Tuple4 b_origin = CalculateBThreadOriginDataIndex(),
Tuple4 a_origin = CalculateAThreadOriginDataIndex())
: switch_flag_(switch_flag), a_thread_copy_(a_origin), b_thread_copy_(b_origin)
{
static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
"wrong!");
}
__host__ __device__ BlockwiseGemmXdlops_v2(const BlockwiseGemmXdlops_v2& other)
: a_thread_copy_(other.a_origin), b_thread_copy_(other.b_origin)
{
...@@ -905,6 +1003,58 @@ struct BlockwiseGemmXdlops_v2
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k;
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k;
__host__ __device__ static constexpr auto MakeCThreadTileIterator()
{
constexpr auto c_thread_lengths = conditional_expr<TransposeC>(
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths(),
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths());
return SpaceFillingCurve<
decltype(c_thread_lengths),
typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
false>{}; // SnakeCurved
}
__host__ __device__ static constexpr auto MakeCThreadIndexAdaptor8DTo2D()
{
if constexpr(TransposeC)
{
constexpr auto c_thread_desc = GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto m0 = c_thread_desc.GetLength(Number<0>{});
constexpr auto n0 = c_thread_desc.GetLength(Number<1>{});
constexpr auto m1 = c_thread_desc.GetLength(Number<2>{});
constexpr auto n1 = c_thread_desc.GetLength(Number<3>{});
constexpr auto m2 = c_thread_desc.GetLength(Number<4>{});
constexpr auto n2 = c_thread_desc.GetLength(Number<5>{});
constexpr auto n3 = c_thread_desc.GetLength(Number<6>{});
constexpr auto n4 = c_thread_desc.GetLength(Number<7>{});
constexpr auto thread_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(m0, m1, m2)),
make_unmerge_transform(make_tuple(n0, n1, n2, n3, n4))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{}));
return thread_idx_to_m_n_adaptor;
}
else
{
constexpr auto c_thread_desc = GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto m0 = c_thread_desc.GetLength(Number<0>{});
constexpr auto n0 = c_thread_desc.GetLength(Number<1>{});
constexpr auto m1 = c_thread_desc.GetLength(Number<2>{});
constexpr auto n1 = c_thread_desc.GetLength(Number<3>{});
constexpr auto m2 = c_thread_desc.GetLength(Number<4>{});
constexpr auto m3 = c_thread_desc.GetLength(Number<5>{});
constexpr auto m4 = c_thread_desc.GetLength(Number<6>{});
constexpr auto n2 = c_thread_desc.GetLength(Number<7>{});
constexpr auto thread_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(m0, m1, m2, m3, m4)),
make_unmerge_transform(make_tuple(n0, n1, n2))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
return thread_idx_to_m_n_adaptor;
}
}
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
...@@ -991,6 +1141,7 @@ struct BlockwiseGemmXdlops_v2
B_K1,
B_K1>;
index_t switch_flag_;
AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
};
......
...@@ -108,6 +108,24 @@ struct BlockwiseSoftmax
});
}
template <typename CThreadBuffer, typename LSEBuffer>
__host__ __device__ void RunWithPreCalcStats(CThreadBuffer& in_thread_buf,
const LSEBuffer& lse_thread_buf)
{
// calculate exp for elements using the pre-calculated statistic LSE (log-sum-exp)
// Pi = exp(Si) / (exp(S0) + exp(S1) + ...)
//    = exp(Si) / exp(log(exp(S0) + exp(S1) + ...))
//    = exp(Si - LSE)
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) = IgnoreNaN && ck::math::isnan(in_thread_buf[offset])
? 0
: math::exp(in_thread_buf[offset] - lse_thread_buf[iM]);
});
});
}
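// RunWithPreCalcStats() above recovers softmax probabilities from raw scores
// and a pre-computed per-row log-sum-exp via Pi = exp(Si - LSE). A standalone
// numeric check against the direct formulation; the score values are
// illustration only.
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    std::vector<double> s = {1.0, 2.0, 0.5, -1.0}; // one row of scores

    // direct softmax denominator
    double sum = 0;
    for(double v : s)
        sum += std::exp(v);

    // pre-calculated statistic: LSE = log(sum_j exp(Sj))
    double lse = std::log(sum);

    for(std::size_t i = 0; i < s.size(); ++i)
    {
        double direct   = std::exp(s[i]) / sum;
        double from_lse = std::exp(s[i] - lse); // what the kernel computes
        std::printf("P%zu: direct=%f from_lse=%f\n", i, direct, from_lse);
    }
    return 0;
}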
BufferType max_value_buf;
BufferType sum_value_buf;
};
......
...@@ -42,7 +42,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(
kernel_batched_multihead_attention_backward_xdl_cshuffle_v1(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
const FloatAB* __restrict__ p_b1_grid,
...@@ -540,7 +540,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1<
const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v1<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
......
...@@ -143,6 +143,16 @@ struct DynamicBuffer
}
}
__host__ __device__ void Clear()
{
static_assert(GetAddressSpace() == AddressSpaceEnum::Lds,
"wrong! only local data share is supported");
for(index_t i = get_thread_local_1d_id(); i < element_space_size_; i += get_block_size())
{
Set(i, true, T{0});
}
}
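// Clear() above uses a block-stride loop: thread t writes elements t,
// t + blockSize, t + 2*blockSize, ... so the whole LDS buffer is zeroed
// exactly once with no overlap between threads. A standalone host model of
// that partitioning; thread count and buffer size are illustration values.
#include <cassert>
#include <vector>

int main()
{
    const int block_size         = 8;  // stand-in for get_block_size()
    const int element_space_size = 50; // stand-in for element_space_size_

    std::vector<int> writes(element_space_size, 0);

    for(int tid = 0; tid < block_size; ++tid)           // each "thread"
        for(int i = tid; i < element_space_size; i += block_size)
            writes[i] += 1;                             // models Set(i, true, T{0})

    // every element is touched by exactly one thread
    for(int w : writes)
        assert(w == 1);
    return 0;
}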
template <typename X,
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value,
...@@ -302,7 +312,9 @@ struct DynamicBuffer
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
static_assert(GetAddressSpace() == AddressSpaceEnum::Global, "only support global mem");
static_assert(GetAddressSpace() == AddressSpaceEnum::Global ||
GetAddressSpace() == AddressSpaceEnum::Lds,
"only support global mem or local data share");
#if CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
bool constexpr use_amd_buffer_addressing =
...@@ -319,7 +331,7 @@ struct DynamicBuffer
bool constexpr use_amd_buffer_addressing = false;
#endif
if constexpr(use_amd_buffer_addressing)
if constexpr(use_amd_buffer_addressing && GetAddressSpace() == AddressSpaceEnum::Global)
{
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
......
...@@ -100,6 +100,17 @@ __host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, const Y& y)
return r;
}
template <typename... Xs, index_t N>
__host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, const Number<N>& y)
{
constexpr index_t NSize = sizeof...(Xs);
// Tuple<Xs...> r;
// static_for<0, NSize, 1>{}([&](auto i) { r(i) = x[i] * y; });
// return r;
return generate_tuple([&](auto i) { return x[i] * y; }, Number<NSize>{});
}
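// The new operator* above builds the scaled tuple functionally with
// generate_tuple instead of default-constructing a Tuple and filling it in a
// static_for (the commented-out variant). A standalone analogue using
// std::tuple and std::index_sequence, showing the same generate-by-index
// pattern; names here are not CK APIs.
#include <cstdio>
#include <tuple>
#include <utility>

template <typename Tuple, typename Scalar, std::size_t... Is>
constexpr auto scale_tuple_impl(const Tuple& x, Scalar y, std::index_sequence<Is...>)
{
    // one element generated per index, no mutable intermediate tuple
    return std::make_tuple((std::get<Is>(x) * y)...);
}

template <typename... Xs, typename Scalar>
constexpr auto scale_tuple(const std::tuple<Xs...>& x, Scalar y)
{
    return scale_tuple_impl(x, y, std::index_sequence_for<Xs...>{});
}

int main()
{
    constexpr auto r = scale_tuple(std::make_tuple(1, 2, 3), 4);
    std::printf("%d %d %d\n", std::get<0>(r), std::get<1>(r), std::get<2>(r));
    return 0;
}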
// MultiIndex = scalar * MultiIndex
template <typename... Xs,
typename Y,
......
...@@ -19,4 +19,37 @@ struct ThisThreadBlock
__device__ static index_t GetThreadId() { return get_thread_local_1d_id(); }
};
template <index_t ThreadPerBlock>
struct SubThreadBlock
{
static constexpr index_t kNumThread_ = ThreadPerBlock;
__device__ SubThreadBlock(int mwave, int nwave) : mwave_(mwave), nwave_(nwave) {}
__device__ static constexpr index_t GetNumOfThread() { return kNumThread_; }
template <typename TupleArg1, typename TupleArg2>
__device__ constexpr bool IsBelong(const TupleArg1& mwave_range, const TupleArg2& nwave_range)
{
// wave_range[I0] inclusive, wave_range[I1] exclusive
if(mwave_ < mwave_range[I0])
return false;
else if(mwave_ >= mwave_range[I1])
return false;
else if(nwave_ < nwave_range[I0])
return false;
else if(nwave_ >= nwave_range[I1])
return false;
else
return true;
}
__device__ static index_t GetThreadId() { return get_thread_local_1d_id(); }
private:
index_t mwave_, nwave_;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
};
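// SubThreadBlock::IsBelong() above tests whether this wave's (mwave, nwave)
// coordinate falls inside half-open ranges [begin, end) in both dimensions.
// A standalone host model of that membership test; the range values are
// illustration only.
#include <cstdio>
#include <utility>

// same half-open convention as IsBelong: first bound inclusive, second exclusive
bool is_belong(int mwave, int nwave, std::pair<int, int> m_range, std::pair<int, int> n_range)
{
    return mwave >= m_range.first && mwave < m_range.second &&
           nwave >= n_range.first && nwave < n_range.second;
}

int main()
{
    // e.g. restrict work to waves with mwave in [0, 2) and nwave in [2, 4)
    std::printf("%d\n", is_belong(1, 3, {0, 2}, {2, 4})); // 1: inside both ranges
    std::printf("%d\n", is_belong(2, 3, {0, 2}, {2, 4})); // 0: mwave hits the exclusive end
    return 0;
}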
} // namespace ck
...@@ -149,6 +149,13 @@ struct ReferenceSoftmax : public device::BaseOperator
ck::type_convert<AccDataType>(
arg.sm_stats_ptr_[0](to_sm_stats_idx(idx)))) +
arg.beta_ * self(idx);
// printf(
// "exponent %f, exp() = %f\n",
// ck::type_convert<AccDataType>(arg.in_(idx)) -
// ck::type_convert<AccDataType>(arg.sm_stats_ptr_[0](to_sm_stats_idx(idx))),
// std::exp(
// ck::type_convert<AccDataType>(arg.in_(idx)) -
// ck::type_convert<AccDataType>(arg.sm_stats_ptr_[0](to_sm_stats_idx(idx)))));
});
return 0;
......
...@@ -11,6 +11,25 @@
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
/*
For an fp16 M-contiguous matrix of size M_K, each thread reads a 4x2 tile (2 * 64 bits) from global
memory, transposes the 4x2 tile inside registers, and writes it into LDS in K0_M_K1 layout. This allows
us to use the 128-bit LDS write instruction. It also avoids write bank conflicts because two
vertically adjacent 4x2 tiles form a contiguous chunk of memory when modeled as a K0_M_K1 layout where
K1=2.
<- K1 -> <- K1 -> <- K1 ->
_________ _________ _________
| | 0 | 4 | transpose | 0 - 1 | to LDS | 0 - 1 |
| | 1 | 5 | ---> | 2 - 3 | ----> | 2 - 3 |
| | 2 | 6 | | 4 - 5 | | 4 - 5 |
M | | 3 | 7 | | 6 - 7 | | 6 - 7 |
| --------- --------- ---------
| | ... | | ... | | ... |
v --------- --------- ---------
VMEM VGPR LDS
*/
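// A standalone host model of the 4x2 register transpose described above: the
// source matrix is M-contiguous (element (m, k) stored at k*M + m), and LDS is
// modeled as a packed K0_M_K1 layout with K1 = 2, so element (m, k) lands at
// (k/2)*M*2 + m*2 + (k%2). Sizes are illustration values, not taken from this
// commit.
#include <cstdio>

int main()
{
    constexpr int M = 8, K = 4, K1 = 2, K0 = K / K1;

    unsigned short vmem[M * K]; // stand-in for fp16 global memory
    for(int k = 0; k < K; ++k)
        for(int m = 0; m < M; ++m)
            vmem[k * M + m] = static_cast<unsigned short>(k * M + m);

    unsigned short lds[K0 * M * K1];

    // one "thread" handles a 4 (M) x 2 (K) tile starting at (m0, k0*K1):
    // it reads two 4-element columns, transposes, and writes four 2-element rows
    for(int k0 = 0; k0 < K0; ++k0)
        for(int m0 = 0; m0 < M; m0 += 4)
            for(int dm = 0; dm < 4; ++dm)
                for(int k1 = 0; k1 < K1; ++k1)
                {
                    int m = m0 + dm, k = k0 * K1 + k1;
                    lds[(k0 * M + m) * K1 + k1] = vmem[k * M + m];
                }

    // each row write of K1 = 2 elements is contiguous in LDS, and tiles stacked
    // along M fill a contiguous K0_M_K1 chunk
    std::printf("lds[0..3] = %u %u %u %u\n", lds[0], lds[1], lds[2], lds[3]);
    return 0;
}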
namespace ck {
namespace tensor_operation {
namespace device {
......
...@@ -58,3 +58,4 @@ add_subdirectory(batchnorm)
if(GPU_TARGETS MATCHES "gfx1100")
add_subdirectory(wmma_op)
endif()
add_subdirectory(host_tensor)
add_gtest_executable(test_host_tensor test_host_tensor.cpp)
target_link_libraries(test_host_tensor PRIVATE utility)
\ No newline at end of file