Commit 5736b460 authored by fsx950223

Merge branch 'my-attn-bwd2' into my-attn-bwd3

parents aace9ec6 f9bb62d5
@@ -9,6 +9,10 @@ add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_train_xdl
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_train_xdl_bf16 batched_gemm_scale_softmax_gemm_permute_train_xdl_bf16.cpp)
add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_batched_multihead_attention_backward_fp16 batched_multihead_attention_backward_fp16.cpp)
add_example_executable(example_batched_multihead_attention_backward_pt1_fp16 batched_multihead_attention_backward_pt1_fp16.cpp)
add_example_executable(example_grouped_multihead_attention_backward_fp16 grouped_multihead_attention_backward_fp16.cpp)
add_example_executable(example_batched_multihead_attention_backward_fp16_dropout batched_multihead_attention_backward_fp16_dropout.cpp)
add_custom_target(example_gemm_scale_softmax_gemm)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
......
@@ -32,7 +32,7 @@ template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using BF16 = ck::bhalf_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
......
@@ -16,12 +16,15 @@ struct BlockwiseDropout
static constexpr index_t MRepeat = ThreadSliceDesc_M_K{}.GetLength(I0);
static constexpr index_t KRepeat = ThreadSliceDesc_M_K{}.GetLength(I1);
-template <typename CThreadBuffer>
+template <typename CThreadBuffer, bool using_sign_bit = false>
__host__ __device__ void ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox ph)
{
auto execute_dropout = [&](bool keep, DataType val) {
-    return keep ? val * p_dropout_rescale : float(0);
+    if constexpr(using_sign_bit)
+        return keep ? val : -val;
+    else
+        return keep ? val * p_dropout_rescale : float(0);
};
constexpr int tmp_size = MRepeat * KRepeat;
@@ -47,6 +50,42 @@ struct BlockwiseDropout
});
}
template <typename CThreadBuffer, typename ZThreadBuffer, bool using_sign_bit = false>
__host__ __device__ void
ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox ph, ZThreadBuffer& z_thread_buf)
{
auto execute_dropout = [&](bool keep, DataType val) {
if constexpr(using_sign_bit)
return keep ? val : -val;
else
return keep ? val * p_dropout_rescale : float(0);
};
constexpr int tmp_size = MRepeat * KRepeat;
int philox_calls = tmp_size / 8;
ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++)
{
ph.get_random_8x16((tmp + i * 8));
}
block_sync_lds();
int tmp_index = 0;
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) =
execute_dropout(tmp[tmp_index] < p_dropout_16bits, in_thread_buf(offset));
z_thread_buf(offset) = tmp[tmp_index];
tmp_index = tmp_index + 1;
});
});
}
ushort p_dropout_16bits;
DataType p_dropout_rescale;
};
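For orientation, here is a minimal host-side sketch of the two dropout policies the new `using_sign_bit` flag selects between: the existing threshold-and-rescale behaviour and the new sign-bit encoding. The random source, the `p_keep_16bits` name, and the probability handling are assumptions for illustration only, not the `ck::philox` plumbing used above.

```cpp
#include <cstdint>

// Sketch only: models the per-element decision ApplyDropout makes after the
// philox calls have filled "tmp" with 16-bit random values. Whether the stored
// threshold encodes the keep or the drop probability is decided by host-side setup.
float dropout_rescale(std::uint16_t rng16, std::uint16_t p_keep_16bits,
                      float val, float rescale)
{
    const bool keep = rng16 < p_keep_16bits;
    return keep ? val * rescale : 0.0f; // classic inverted dropout: rescale by 1/(1-p)
}

// Alternative policy selected by using_sign_bit = true: the drop decision is
// carried in the sign of the value instead of zeroing it, so a later pass can
// still recover both the mask and the magnitude.
float dropout_sign_bit(std::uint16_t rng16, std::uint16_t p_keep_16bits, float val)
{
    const bool keep = rng16 < p_keep_16bits;
    return keep ? val : -val;
}

int main()
{
    const float p_drop = 0.2f;
    const auto p_keep_16bits = static_cast<std::uint16_t>((1.0f - p_drop) * 65535.0f);
    // 10000 < p_keep_16bits, so both policies keep this element.
    const float kept    = dropout_rescale(10000, p_keep_16bits, 1.5f, 1.0f / (1.0f - p_drop));
    const float flagged = dropout_sign_bit(10000, p_keep_16bits, 1.5f);
    return (kept > 0.0f && flagged > 0.0f) ? 0 : 1;
}
```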
......
@@ -50,7 +50,8 @@ template <index_t BlockSize,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
-index_t KPack>
+index_t KPack,
+bool TransposeC = false>
struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
{
static constexpr auto I0 = Number<0>{};
@@ -72,7 +73,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
-static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack>{};
+static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, TransposeC>{};
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
@@ -185,6 +186,21 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
"wrong!");
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, N, M0, M1, M2));
}
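The transposed descriptors rest on the identity (A·B)ᵀ = Bᵀ·Aᵀ: emitting the XDL output as C' = B'·A' yields the transpose of C = A·B directly, which is what the comment above alludes to. A small standalone check of that identity, plain loops and nothing CK-specific:

```cpp
#include <cassert>

// Verifies (A*B)^T == B^T * A^T on a tiny example; this is the algebra behind
// the TransposeC path computing C' = B' * A' instead of transposing C afterwards.
int main()
{
    constexpr int M = 2, K = 3, N = 2;
    const int A[M][K] = {{1, 2, 3}, {4, 5, 6}};
    const int B[K][N] = {{7, 8}, {9, 10}, {11, 12}};

    int C[M][N]  = {}; // C  = A * B
    int Ct[N][M] = {}; // Ct = B^T * A^T
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
            {
                C[m][n]  += A[m][k] * B[k][n];
                Ct[n][m] += B[k][n] * A[m][k];
            }

    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            assert(Ct[n][m] == C[m][n]);
    return 0;
}
```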
// XDL output supporting C_xdl = A_xdl * B_xdl
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
@@ -211,6 +227,21 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
{
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2);
}
// XDL output supporting C_xdl = A_xdl * B_xdl
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
@@ -303,6 +334,58 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K();
static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K();
__host__ __device__ static constexpr auto MakeCThreadTileIterator()
{
constexpr auto c_thread_lengths = conditional_expr<TransposeC>(
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths(),
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths());
return SpaceFillingCurve<
decltype(c_thread_lengths),
typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
false>{}; // SnakeCurved
}
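MakeCThreadTileIterator builds a SpaceFillingCurve over the per-thread C lengths with unit scalar steps and snake traversal disabled (the trailing `false`). As a rough sketch of what that last flag controls in general, and not of CK's class itself, compare a plain row-major walk with a snake walk over a small 2-D index space:

```cpp
#include <cstdio>

// Sketch: enumerate a Len0 x Len1 index space in plain row-major order
// (snake disabled) and in snake order, where the inner dimension reverses
// direction on every other outer step so consecutive visits stay adjacent.
int main()
{
    constexpr int Len0 = 2, Len1 = 4;

    std::printf("row-major: ");
    for(int i0 = 0; i0 < Len0; ++i0)
        for(int i1 = 0; i1 < Len1; ++i1)
            std::printf("(%d,%d) ", i0, i1);

    std::printf("\nsnake    : ");
    for(int i0 = 0; i0 < Len0; ++i0)
        for(int j = 0; j < Len1; ++j)
        {
            const int i1 = (i0 % 2 == 0) ? j : (Len1 - 1 - j);
            std::printf("(%d,%d) ", i0, i1);
        }
    std::printf("\n");
    return 0;
}
```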
__host__ __device__ static constexpr auto MakeCThreadIndexAdaptor8DTo2D()
{
if constexpr(TransposeC)
{
constexpr auto c_thread_desc = GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto m0 = c_thread_desc.GetLength(Number<0>{});
constexpr auto n0 = c_thread_desc.GetLength(Number<1>{});
constexpr auto m1 = c_thread_desc.GetLength(Number<2>{});
constexpr auto n1 = c_thread_desc.GetLength(Number<3>{});
constexpr auto m2 = c_thread_desc.GetLength(Number<4>{});
constexpr auto n2 = c_thread_desc.GetLength(Number<5>{});
constexpr auto n3 = c_thread_desc.GetLength(Number<6>{});
constexpr auto n4 = c_thread_desc.GetLength(Number<7>{});
constexpr auto thread_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(m0, m1, m2)),
make_unmerge_transform(make_tuple(n0, n1, n2, n3, n4))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{}));
return thread_idx_to_m_n_adaptor;
}
else
{
constexpr auto c_thread_desc = GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto m0 = c_thread_desc.GetLength(Number<0>{});
constexpr auto n0 = c_thread_desc.GetLength(Number<1>{});
constexpr auto m1 = c_thread_desc.GetLength(Number<2>{});
constexpr auto n1 = c_thread_desc.GetLength(Number<3>{});
constexpr auto m2 = c_thread_desc.GetLength(Number<4>{});
constexpr auto m3 = c_thread_desc.GetLength(Number<5>{});
constexpr auto m4 = c_thread_desc.GetLength(Number<6>{});
constexpr auto n2 = c_thread_desc.GetLength(Number<7>{});
constexpr auto thread_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(m0, m1, m2, m3, m4)),
make_unmerge_transform(make_tuple(n0, n1, n2))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
return thread_idx_to_m_n_adaptor;
}
}
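The adaptor regroups the eight thread-descriptor dimensions into an M group and an N group and merges each group back into a single coordinate; the Sequence arguments record which of the eight dimensions belong to which group. A scalar sketch of that re-merge, mirroring what the unmerge transforms express without CK types (the lengths and coordinates below are hypothetical):

```cpp
#include <array>
#include <cstddef>
#include <cstdio>

// Re-merge split coordinates back into one flat index, row-major over the group:
// idx = ((c0 * L1 + c1) * L2 + c2) * ...
template <std::size_t N>
long merge(const std::array<long, N>& coord, const std::array<long, N>& len)
{
    long idx = 0;
    for(std::size_t i = 0; i < N; ++i)
        idx = idx * len[i] + coord[i];
    return idx;
}

int main()
{
    // Hypothetical per-thread lengths for the transposed layout, regrouped as
    // m <- (m0, m1, m2) and n <- (n0, n1, n2, n3, n4).
    const std::array<long, 3> m_len{2, 1, 4}, m_coord{1, 0, 3};
    const std::array<long, 5> n_len{2, 1, 4, 1, 4}, n_coord{1, 0, 2, 0, 3};

    std::printf("m = %ld, n = %ld\n", merge(m_coord, m_len), merge(n_coord, n_len));
    return 0;
}
```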
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
@@ -776,6 +859,21 @@ struct BlockwiseGemmXdlops_v2
"wrong!");
}
__host__ __device__ BlockwiseGemmXdlops_v2(index_t switch_flag,
Tuple4 b_origin = CalculateBThreadOriginDataIndex(),
Tuple4 a_origin = CalculateAThreadOriginDataIndex())
: switch_flag_(switch_flag), a_thread_copy_(a_origin), b_thread_copy_(b_origin)
{
static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
"wrong!");
}
__host__ __device__ BlockwiseGemmXdlops_v2(const BlockwiseGemmXdlops_v2& other)
: a_thread_copy_(other.a_origin), b_thread_copy_(other.b_origin)
{
@@ -905,6 +1003,58 @@ struct BlockwiseGemmXdlops_v2
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k;
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k;
__host__ __device__ static constexpr auto MakeCThreadTileIterator()
{
constexpr auto c_thread_lengths = conditional_expr<TransposeC>(
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths(),
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths());
return SpaceFillingCurve<
decltype(c_thread_lengths),
typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
false>{}; // SnakeCurved
}
__host__ __device__ static constexpr auto MakeCThreadIndexAdaptor8DTo2D()
{
if constexpr(TransposeC)
{
constexpr auto c_thread_desc = GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto m0 = c_thread_desc.GetLength(Number<0>{});
constexpr auto n0 = c_thread_desc.GetLength(Number<1>{});
constexpr auto m1 = c_thread_desc.GetLength(Number<2>{});
constexpr auto n1 = c_thread_desc.GetLength(Number<3>{});
constexpr auto m2 = c_thread_desc.GetLength(Number<4>{});
constexpr auto n2 = c_thread_desc.GetLength(Number<5>{});
constexpr auto n3 = c_thread_desc.GetLength(Number<6>{});
constexpr auto n4 = c_thread_desc.GetLength(Number<7>{});
constexpr auto thread_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(m0, m1, m2)),
make_unmerge_transform(make_tuple(n0, n1, n2, n3, n4))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{}));
return thread_idx_to_m_n_adaptor;
}
else
{
constexpr auto c_thread_desc = GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto m0 = c_thread_desc.GetLength(Number<0>{});
constexpr auto n0 = c_thread_desc.GetLength(Number<1>{});
constexpr auto m1 = c_thread_desc.GetLength(Number<2>{});
constexpr auto n1 = c_thread_desc.GetLength(Number<3>{});
constexpr auto m2 = c_thread_desc.GetLength(Number<4>{});
constexpr auto m3 = c_thread_desc.GetLength(Number<5>{});
constexpr auto m4 = c_thread_desc.GetLength(Number<6>{});
constexpr auto n2 = c_thread_desc.GetLength(Number<7>{});
constexpr auto thread_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(m0, m1, m2, m3, m4)),
make_unmerge_transform(make_tuple(n0, n1, n2))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
return thread_idx_to_m_n_adaptor;
}
}
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
@@ -991,6 +1141,7 @@ struct BlockwiseGemmXdlops_v2
B_K1,
B_K1>;
index_t switch_flag_;
AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
};
......
@@ -108,6 +108,24 @@ struct BlockwiseSoftmax
});
}
template <typename CThreadBuffer, typename LSEBuffer>
__host__ __device__ void RunWithPreCalcStats(CThreadBuffer& in_thread_buf,
const LSEBuffer& lse_thread_buf)
{
// calculate exp for elements using pre-calculated stats LSE (log-sum-exp)
// Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// = exp(Si) / exp(log(sum(exp() + ...)))
// = exp(Si - log(sum(exp() + ...)))
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) = IgnoreNaN && ck::math::isnan(in_thread_buf[offset])
? 0
: math::exp(in_thread_buf[offset] - lse_thread_buf[iM]);
});
});
}
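RunWithPreCalcStats rebuilds the softmax probabilities from a stored log-sum-exp instead of redoing the max/sum reductions, i.e. P_i = exp(S_i - LSE), the standard trick in attention backward passes. A small host-side sketch of the same arithmetic:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    const std::vector<float> s{1.0f, 2.5f, -0.5f, 3.0f}; // one row of scores S

    // The forward pass would store lse = max + log(sum(exp(s - max))).
    const float mx = *std::max_element(s.begin(), s.end());
    float sum      = 0.0f;
    for(float v : s)
        sum += std::exp(v - mx);
    const float lse = mx + std::log(sum);

    // The backward pass recomputes probabilities directly from the stored LSE:
    // p_i = exp(s_i) / sum_j exp(s_j) = exp(s_i - lse).
    float total = 0.0f;
    for(float v : s)
        total += std::exp(v - lse);
    std::printf("probabilities sum to %f (expected 1)\n", total);
    return 0;
}
```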
BufferType max_value_buf;
BufferType sum_value_buf;
};
......
@@ -42,7 +42,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
-kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(
+kernel_batched_multihead_attention_backward_xdl_cshuffle_v1(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
const FloatAB* __restrict__ p_b1_grid,
@@ -540,7 +540,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
-const auto kernel = kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1<
+const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v1<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
......
@@ -95,6 +95,8 @@ struct Scale
y = scale_ * x;
};
__host__ __device__ void Append(float scale) { scale_ = scale_ * scale; }
float scale_;
};
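Append lets a caller fold an additional factor into an already-constructed Scale element-wise op instead of building a new one. A minimal stand-in sketch of how the composition behaves (the 1/sqrt(d) and rescale values are illustrative, not taken from this commit):

```cpp
// Minimal stand-in for the element-wise op; only shows how Append composes factors.
struct Scale
{
    void operator()(float& y, const float& x) const { y = scale_ * x; }
    void Append(float scale) { scale_ = scale_ * scale; }
    float scale_;
};

int main()
{
    Scale op{0.125f};  // e.g. softmax scaling 1/sqrt(d) for d = 64
    op.Append(1.25f);  // fold in another factor, e.g. a dropout rescale
    float y = 0.0f;
    op(y, 2.0f);       // y = 0.125 * 1.25 * 2.0
    return y > 0.0f ? 0 : 1;
}
```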
......
@@ -143,6 +143,16 @@ struct DynamicBuffer
}
}
__host__ __device__ void Clear()
{
static_assert(GetAddressSpace() == AddressSpaceEnum::Lds,
"wrong! only local data share is supported");
for(index_t i = get_thread_local_1d_id(); i < element_space_size_; i += get_block_size())
{
Set(i, true, T{0});
}
}
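Clear zeroes an LDS-backed buffer cooperatively: each thread starts at its own lane id and strides by the block size, so every element is written exactly once. A generic HIP-style sketch of that pattern, not the CK helpers themselves (the real Clear uses get_thread_local_1d_id, get_block_size, and the buffer's Set):

```cpp
#include <hip/hip_runtime.h>

// Block-cooperative clear of a shared-memory tile; each thread strides by the
// block size so element i is written by exactly one thread.
__device__ void clear_shared(float* smem, int element_count)
{
    for(int i = threadIdx.x; i < element_count; i += blockDim.x)
    {
        smem[i] = 0.0f;
    }
    __syncthreads(); // make the cleared tile visible to the whole block
}

__global__ void example_kernel()
{
    __shared__ float tile[256];
    clear_shared(tile, 256);
    // ... use the zero-initialized tile ...
}
```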
template <typename X,
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value,
@@ -302,7 +312,9 @@ struct DynamicBuffer
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
-static_assert(GetAddressSpace() == AddressSpaceEnum::Global, "only support global mem");
+static_assert(GetAddressSpace() == AddressSpaceEnum::Global ||
+                  GetAddressSpace() == AddressSpaceEnum::Lds,
+              "only support global mem or local data share");
#if CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
bool constexpr use_amd_buffer_addressing =
@@ -319,7 +331,7 @@ struct DynamicBuffer
bool constexpr use_amd_buffer_addressing = false;
#endif
-if constexpr(use_amd_buffer_addressing)
+if constexpr(use_amd_buffer_addressing && GetAddressSpace() == AddressSpaceEnum::Global)
{
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
......
@@ -100,6 +100,17 @@ __host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, const Y& y)
return r;
}
template <typename... Xs, index_t N>
__host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, const Number<N>& y)
{
constexpr index_t NSize = sizeof...(Xs);
// Tuple<Xs...> r;
// static_for<0, NSize, 1>{}([&](auto i) { r(i) = x[i] * y; });
// return r;
return generate_tuple([&](auto i) { return x[i] * y; }, Number<NSize>{});
}
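The new overload builds the scaled tuple in a single generate_tuple expression rather than mutating a copy (the commented-out lines show the older pattern). The same idea in standard C++, as a rough sketch, is an index-sequence expansion:

```cpp
#include <cstdio>
#include <tuple>
#include <utility>

// Multiply every element of a tuple by a scalar, building the result in one
// expression (analogous to generate_tuple) instead of mutating a copy in a loop.
template <typename Tuple, typename Y, std::size_t... Is>
constexpr auto scale_tuple_impl(const Tuple& x, const Y& y, std::index_sequence<Is...>)
{
    return std::make_tuple((std::get<Is>(x) * y)...);
}

template <typename... Xs, typename Y>
constexpr auto scale_tuple(const std::tuple<Xs...>& x, const Y& y)
{
    return scale_tuple_impl(x, y, std::index_sequence_for<Xs...>{});
}

int main()
{
    constexpr auto r = scale_tuple(std::make_tuple(1, 2.5, 4L), 3);
    std::printf("%d %f %ld\n", std::get<0>(r), std::get<1>(r), std::get<2>(r));
    return 0;
}
```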
// MultiIndex = scalar * MultiIndex
template <typename... Xs,
typename Y,
......