Unverified Commit 965021d2 authored by M.Emin Ozturk's avatar M.Emin Ozturk Committed by GitHub
Browse files

Merge branch 'develop' into gemm_bf16_sk_muozturk

parents 8a34c640 7d8ea5f0
...@@ -68,7 +68,7 @@ __global__ void ...@@ -68,7 +68,7 @@ __global__ void
const index_t N = gemm_desc_ptr[group_id].N; const index_t N = gemm_desc_ptr[group_id].N;
const index_t K = gemm_desc_ptr[group_id].K; const index_t K = gemm_desc_ptr[group_id].K;
if(M * N * K == 0) if(M == 0 || N == 0 || K == 0)
return; return;
const auto StrideA = gemm_desc_ptr[group_id].StrideA; const auto StrideA = gemm_desc_ptr[group_id].StrideA;
......
...@@ -13,7 +13,6 @@ namespace conv { ...@@ -13,7 +13,6 @@ namespace conv {
struct ConvParam struct ConvParam
{ {
ConvParam();
ConvParam(ck_tile::index_t n_dim, ConvParam(ck_tile::index_t n_dim,
ck_tile::index_t group_count, ck_tile::index_t group_count,
ck_tile::index_t n_batch, ck_tile::index_t n_batch,
...@@ -199,11 +198,6 @@ struct ConvParam ...@@ -199,11 +198,6 @@ struct ConvParam
} }
}; };
ConvParam::ConvParam()
: ConvParam::ConvParam(2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1})
{
}
CK_TILE_HOST std::string get_conv_param_parser_helper_msg() CK_TILE_HOST std::string get_conv_param_parser_helper_msg()
{ {
std::string msg; std::string msg;
......
...@@ -6,8 +6,11 @@ ...@@ -6,8 +6,11 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp" #include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include <string> #include <string>
#include <type_traits> #include <type_traits>
#include <utility>
#include <variant>
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
...@@ -194,11 +197,39 @@ struct FmhaBwdDQDKDVKernel ...@@ -194,11 +197,39 @@ struct FmhaBwdDQDKDVKernel
ck_tile::GenericAttentionMaskEnum mask_type; ck_tile::GenericAttentionMaskEnum mask_type;
}; };
struct FmhaBwdCommonDropoutKargs struct FmhaBwdDropoutSeedOffset
{ {
void init_dropout(const float p_drop, template <typename T>
const std::tuple<uint64_t, uint64_t>& drop_seed_offset, union ValueOrPointer
const float raw_scale) {
T val;
const T* ptr;
};
ValueOrPointer<uint64_t> drop_seed;
ValueOrPointer<uint64_t> drop_offset;
bool is_drop_seed_offset_from_host;
};
struct FmhaBwdCommonDropoutKargs : FmhaBwdDropoutSeedOffset
{
void init_dropout(float p_drop, uint64_t seed, uint64_t offset, float raw_scale)
{
float p_undrop = 1.0 - p_drop;
p_undrop_in_uint8_t =
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
rp_undrop = 1.0 / p_undrop;
scale_rp_undrop = rp_undrop * raw_scale;
this->drop_seed.val = seed;
this->drop_offset.val = offset;
this->is_drop_seed_offset_from_host = true;
}
void init_dropout(float p_drop,
const uint64_t* seed_ptr,
const uint64_t* offset_ptr,
float raw_scale)
{ {
float p_undrop = 1.0 - p_drop; float p_undrop = 1.0 - p_drop;
p_undrop_in_uint8_t = p_undrop_in_uint8_t =
...@@ -206,23 +237,25 @@ struct FmhaBwdDQDKDVKernel ...@@ -206,23 +237,25 @@ struct FmhaBwdDQDKDVKernel
rp_undrop = 1.0 / p_undrop; rp_undrop = 1.0 / p_undrop;
scale_rp_undrop = rp_undrop * raw_scale; scale_rp_undrop = rp_undrop * raw_scale;
drop_seed = std::get<0>(drop_seed_offset); this->drop_seed.ptr = seed_ptr;
drop_offset = std::get<1>(drop_seed_offset); this->drop_offset.ptr = offset_ptr;
this->is_drop_seed_offset_from_host = false;
} }
float rp_undrop = 1; float rp_undrop = 1;
float scale_rp_undrop = 1; float scale_rp_undrop = 1;
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max(); uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
uint64_t drop_seed = 1;
uint64_t drop_offset = 0;
void* rand_val_ptr = nullptr; void* rand_val_ptr = nullptr;
ck_tile::index_t stride_randval = 0; ck_tile::index_t stride_randval = 0;
ck_tile::index_t nhead_stride_randval = 0; ck_tile::index_t nhead_stride_randval = 0;
}; };
struct FmhaBwdBatchModeDropoutKargs : FmhaBwdCommonDropoutKargs struct FmhaBwdBatchModeDropoutKargs : FmhaBwdCommonDropoutKargs
{ {
ck_tile::index_t batch_stride_randval = 0; ck_tile::index_t batch_stride_randval = 0;
}; };
struct FmhaBwdDeterministicKargs struct FmhaBwdDeterministicKargs
{ {
ck_tile::index_t split_stride_dq_acc = 0; ck_tile::index_t split_stride_dq_acc = 0;
...@@ -327,7 +360,8 @@ struct FmhaBwdDQDKDVKernel ...@@ -327,7 +360,8 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t window_size_right, ck_tile::index_t window_size_right,
ck_tile::index_t mask_type, ck_tile::index_t mask_type,
float p_drop, float p_drop,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset) std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{ {
Kargs kargs{{q_ptr, Kargs kargs{{q_ptr,
k_ptr, k_ptr,
...@@ -405,7 +439,20 @@ struct FmhaBwdDQDKDVKernel ...@@ -405,7 +439,20 @@ struct FmhaBwdDQDKDVKernel
if constexpr(kHasDropout) if constexpr(kHasDropout)
{ {
kargs.init_dropout(p_drop, drop_seed_offset, scale); if(drop_seed_offset.index() == 0) // seed & offset come from host
{
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
kargs.init_dropout(p_drop, seed, offset, scale);
}
else // seed & offset come from device
{
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
kargs.init_dropout(p_drop,
reinterpret_cast<const uint64_t*>(seed_ptr),
reinterpret_cast<const uint64_t*>(offset_ptr),
scale);
}
if constexpr(kIsStoreRandval) if constexpr(kIsStoreRandval)
{ {
kargs.rand_val_ptr = rand_val_ptr; kargs.rand_val_ptr = rand_val_ptr;
...@@ -471,7 +518,8 @@ struct FmhaBwdDQDKDVKernel ...@@ -471,7 +518,8 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t window_size_right, ck_tile::index_t window_size_right,
ck_tile::index_t mask_type, ck_tile::index_t mask_type,
float p_drop, float p_drop,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset) std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{ {
Kargs kargs{{q_ptr, Kargs kargs{{q_ptr,
k_ptr, k_ptr,
...@@ -539,7 +587,20 @@ struct FmhaBwdDQDKDVKernel ...@@ -539,7 +587,20 @@ struct FmhaBwdDQDKDVKernel
} }
if constexpr(kHasDropout) if constexpr(kHasDropout)
{ {
kargs.init_dropout(p_drop, drop_seed_offset, scale); if(drop_seed_offset.index() == 0) // seed & offset come from host
{
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
kargs.init_dropout(p_drop, seed, offset, scale);
}
else // seed & offset come from device
{
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
kargs.init_dropout(p_drop,
reinterpret_cast<const uint64_t*>(seed_ptr),
reinterpret_cast<const uint64_t*>(offset_ptr),
scale);
}
if constexpr(kIsStoreRandval) if constexpr(kIsStoreRandval)
{ {
kargs.rand_val_ptr = rand_val_ptr; kargs.rand_val_ptr = rand_val_ptr;
...@@ -958,8 +1019,10 @@ struct FmhaBwdDQDKDVKernel ...@@ -958,8 +1019,10 @@ struct FmhaBwdDQDKDVKernel
return FmhaDropout{i_batch_, return FmhaDropout{i_batch_,
i_nhead_, i_nhead_,
kargs.num_head_q, kargs.num_head_q,
kargs.drop_seed, kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
kargs.drop_offset, : *kargs.drop_seed.ptr,
kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val
: *kargs.drop_offset.ptr,
kargs.rp_undrop, kargs.rp_undrop,
kargs.p_undrop_in_uint8_t}; kargs.p_undrop_in_uint8_t};
} }
......
...@@ -6,8 +6,11 @@ ...@@ -6,8 +6,11 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp" #include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include <string> #include <string>
#include <type_traits> #include <type_traits>
#include <utility>
#include <variant>
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
...@@ -170,29 +173,55 @@ struct FmhaFwdKernel ...@@ -170,29 +173,55 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_lse = 0; ck_tile::index_t batch_stride_lse = 0;
}; };
struct FmhaFwdCommonDropoutKargs struct FmhaFwdDropoutSeedOffset
{ {
void init_dropout(const float p_drop, template <typename T>
const std::tuple<uint64_t, uint64_t>& drop_seed_offset) union ValueOrPointer
{
T val;
const T* ptr;
};
ValueOrPointer<uint64_t> drop_seed;
ValueOrPointer<uint64_t> drop_offset;
bool is_drop_seed_offset_from_host;
};
struct FmhaFwdCommonDropoutKargs : FmhaFwdDropoutSeedOffset
{
void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
{
float p_undrop = 1.0 - p_drop;
p_undrop_in_uint8_t =
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
rp_undrop = 1.0 / p_undrop;
this->drop_seed.val = seed;
this->drop_offset.val = offset;
this->is_drop_seed_offset_from_host = true;
}
void init_dropout(float p_drop, const uint64_t* seed_ptr, const uint64_t* offset_ptr)
{ {
float p_undrop = 1.0 - p_drop; float p_undrop = 1.0 - p_drop;
p_undrop_in_uint8_t = p_undrop_in_uint8_t =
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max())); uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
rp_undrop = 1.0 / p_undrop; rp_undrop = 1.0 / p_undrop;
drop_seed = std::get<0>(drop_seed_offset); this->drop_seed.ptr = seed_ptr;
drop_offset = std::get<1>(drop_seed_offset); this->drop_offset.ptr = offset_ptr;
this->is_drop_seed_offset_from_host = false;
} }
float rp_undrop = 1; float rp_undrop = 1;
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max(); uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
bool is_store_randval = false; bool is_store_randval = false;
uint64_t drop_seed = 1;
uint64_t drop_offset = 0;
void* rand_val_ptr = nullptr; void* rand_val_ptr = nullptr;
ck_tile::index_t stride_randval = 0; ck_tile::index_t stride_randval = 0;
ck_tile::index_t nhead_stride_randval = 0; ck_tile::index_t nhead_stride_randval = 0;
}; };
struct FmhaFwdBatchModeDropoutKargs : FmhaFwdCommonDropoutKargs struct FmhaFwdBatchModeDropoutKargs : FmhaFwdCommonDropoutKargs
{ {
ck_tile::index_t batch_stride_randval = 0; ck_tile::index_t batch_stride_randval = 0;
...@@ -278,7 +307,8 @@ struct FmhaFwdKernel ...@@ -278,7 +307,8 @@ struct FmhaFwdKernel
ck_tile::index_t mask_type, ck_tile::index_t mask_type,
float p_drop, float p_drop,
bool s_randval, bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset) std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{ {
Kargs kargs{{q_ptr, Kargs kargs{{q_ptr,
k_ptr, k_ptr,
...@@ -344,7 +374,19 @@ struct FmhaFwdKernel ...@@ -344,7 +374,19 @@ struct FmhaFwdKernel
} }
if constexpr(kHasDropout) if constexpr(kHasDropout)
{ {
kargs.init_dropout(p_drop, drop_seed_offset); if(drop_seed_offset.index() == 0) // seed & offset come from host
{
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
kargs.init_dropout(p_drop, seed, offset);
}
else // seed & offset come from device
{
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
kargs.init_dropout(p_drop,
reinterpret_cast<const uint64_t*>(seed_ptr),
reinterpret_cast<const uint64_t*>(offset_ptr));
}
kargs.rand_val_ptr = rand_val_ptr; kargs.rand_val_ptr = rand_val_ptr;
kargs.stride_randval = stride_randval; kargs.stride_randval = stride_randval;
kargs.nhead_stride_randval = nhead_stride_randval; kargs.nhead_stride_randval = nhead_stride_randval;
...@@ -392,7 +434,8 @@ struct FmhaFwdKernel ...@@ -392,7 +434,8 @@ struct FmhaFwdKernel
ck_tile::index_t mask_type, ck_tile::index_t mask_type,
float p_drop, float p_drop,
bool s_randval, bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset) std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{ {
Kargs kargs{{q_ptr, Kargs kargs{{q_ptr,
k_ptr, k_ptr,
...@@ -455,7 +498,19 @@ struct FmhaFwdKernel ...@@ -455,7 +498,19 @@ struct FmhaFwdKernel
} }
if constexpr(kHasDropout) if constexpr(kHasDropout)
{ {
kargs.init_dropout(p_drop, drop_seed_offset); if(drop_seed_offset.index() == 0) // seed & offset come from host
{
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
kargs.init_dropout(p_drop, seed, offset);
}
else // seed & offset come from device
{
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
kargs.init_dropout(p_drop,
reinterpret_cast<const uint64_t*>(seed_ptr),
reinterpret_cast<const uint64_t*>(offset_ptr));
}
kargs.rand_val_ptr = rand_val_ptr; kargs.rand_val_ptr = rand_val_ptr;
kargs.stride_randval = stride_randval; kargs.stride_randval = stride_randval;
kargs.nhead_stride_randval = nhead_stride_randval; kargs.nhead_stride_randval = nhead_stride_randval;
...@@ -748,8 +803,10 @@ struct FmhaFwdKernel ...@@ -748,8 +803,10 @@ struct FmhaFwdKernel
return BlockDropout{i_batch_, return BlockDropout{i_batch_,
i_nhead_, i_nhead_,
kargs.num_head_q, kargs.num_head_q,
kargs.drop_seed, kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
kargs.drop_offset, : *kargs.drop_seed.ptr,
kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val
: *kargs.drop_offset.ptr,
kargs.rp_undrop, kargs.rp_undrop,
kargs.p_undrop_in_uint8_t, kargs.p_undrop_in_uint8_t,
kargs.is_store_randval}; kargs.is_store_randval};
......
...@@ -31,8 +31,14 @@ struct Layernorm2dFwd ...@@ -31,8 +31,14 @@ struct Layernorm2dFwd
static constexpr ck_tile::index_t kMPerBlock = Problem::BlockShape::kMPerBlock; static constexpr ck_tile::index_t kMPerBlock = Problem::BlockShape::kMPerBlock;
static constexpr ck_tile::index_t kNPerBlock = Problem::BlockShape::kNPerBlock; static constexpr ck_tile::index_t kNPerBlock = Problem::BlockShape::kNPerBlock;
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr ck_tile::index_t kNThreadPerWarp = Problem::BlockShape::kNThreadPerWarp; static constexpr ck_tile::index_t kNThreadPerWarp = Problem::BlockShape::kNThreadPerWarp;
static constexpr ck_tile::index_t kNPerThread = Problem::BlockShape::kNPerThread;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
struct Kargs struct Kargs
{ {
...@@ -96,19 +102,25 @@ struct Layernorm2dFwd ...@@ -96,19 +102,25 @@ struct Layernorm2dFwd
sequence<2>>{}); sequence<2>>{});
} }
template <typename Dstr> CK_TILE_DEVICE static int GetWelfordMaxCount(int N)
CK_TILE_DEVICE static constexpr auto GetNPerThread(Dstr)
{ {
constexpr auto nDstrSpan = Dstr::get_distributed_spans().template at<1>(); constexpr ck_tile::index_t kNThreadPerBlock = kNPerBlock / kNPerThread;
using Lengths = decltype(nDstrSpan.impl_);
ck_tile::index_t ret = 1; int thread_id_n = get_thread_id() % kNThreadPerBlock;
int max_count =
__builtin_amdgcn_readfirstlane(N < kNPerBlock ? 0 : kNPerThread * (N / kNPerBlock));
int n_per_block_tail_loop =
__builtin_amdgcn_readfirstlane(N - max_count * kNThreadPerBlock);
ck_tile::static_for<0, Lengths::size(), 1>{}( if(n_per_block_tail_loop > 0)
[&](auto idx) { ret *= Lengths::template at(idx); }); {
int thread_max_n = (thread_id_n + 1) * kNPerThread;
int delta = thread_max_n - n_per_block_tail_loop;
delta = clamp(thread_max_n - n_per_block_tail_loop, 0, kNPerThread);
max_count += kNPerThread - delta;
}
return ret; return max_count;
} }
template <typename DistributedTensor> template <typename DistributedTensor>
...@@ -129,42 +141,29 @@ struct Layernorm2dFwd ...@@ -129,42 +141,29 @@ struct Layernorm2dFwd
return out_dstr_tensor; return out_dstr_tensor;
} }
template <bool Cond = (kHasGamma && kHasBeta)> template <typename XBlockWindow,
CK_TILE_DEVICE std::enable_if_t<Cond> TwoPassLayernorm2dFwd(const XDataType* p_x, typename GammaBlockWindow,
const GammaDataType* p_gamma, typename BetaBlockWindow,
const BetaDataType* p_beta, typename YBlockWindow,
YDataType* p_y, typename MeanBlockWindow,
MeanDataType* p_mean, typename InvStdBlockWindow,
InvStdDataType* p_invStd, bool Cond = (kHasGamma && kHasBeta)>
const ComputeDataType epsilon, CK_TILE_DEVICE std::enable_if_t<Cond>
ck_tile::index_t M, TwoPassLayernorm2dFwd(XBlockWindow& x_block_window,
ck_tile::index_t N) const GammaBlockWindow& gamma_block_window,
BetaBlockWindow& beta_block_window,
YBlockWindow& y_block_window,
MeanBlockWindow& mean_block_window,
InvStdBlockWindow& inv_std_block_window,
ComputeDataType epsilon,
ck_tile::index_t N) const
{ {
constexpr auto I0 = number<0>{}; // TODO - Optimize tail loop to reduce move_tile_window()
constexpr auto I1 = number<1>{}; index_t num_n_tile_iteration =
__builtin_amdgcn_readfirstlane(integer_divide_ceil(N, kNPerBlock));
const auto x_m_n = make_naive_tensor_view<address_space_enum::global>(
p_x, make_tuple(M, N), make_tuple(N, 1), number<32>{}, number<1>{});
const auto gamma_n = make_naive_tensor_view<address_space_enum::global>(
p_gamma, make_tuple(N), make_tuple(1), number<32>{}, number<1>{});
const auto beta_n = make_naive_tensor_view<address_space_enum::global>( int welford_max_count = GetWelfordMaxCount(N);
p_beta, make_tuple(N), make_tuple(1), number<32>{}, number<1>{}); ThreadWelford<ComputeDataType, XDataType> thread_welford{welford_max_count};
const auto iM = get_block_id() * kMPerBlock;
constexpr auto xDstr = MakeXBlockTileDistribution();
auto x_block_window = make_tile_window(
x_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0}, xDstr);
index_t num_n_tile_iteration = __builtin_amdgcn_readfirstlane(N / kNPerBlock);
// TODO: padding - handle max_count if N % kNPerBlock != 0
constexpr auto NPerThread = GetNPerThread(xDstr);
ThreadWelford<ComputeDataType, XDataType> thread_welford{
type_convert<int>(NPerThread * N / kNPerBlock)};
using XTensorType = decltype(load_tile(x_block_window)); using XTensorType = decltype(load_tile(x_block_window));
auto mean_compute_block_tensor = auto mean_compute_block_tensor =
...@@ -190,44 +189,14 @@ struct Layernorm2dFwd ...@@ -190,44 +189,14 @@ struct Layernorm2dFwd
auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon); auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon);
if constexpr(kSaveMean) if constexpr(kSaveMean)
{
const auto mean_m = make_naive_tensor_view_packed<address_space_enum::global>(
p_mean, make_tuple(M), number<32>{});
auto mean_block_window =
make_tile_window(mean_m, make_tuple(number<kMPerBlock>{}), {iM});
store_tile(mean_block_window, cast_tile<MeanDataType>(mean_compute_block_tensor)); store_tile(mean_block_window, cast_tile<MeanDataType>(mean_compute_block_tensor));
}
if constexpr(kSaveInvStd) if constexpr(kSaveInvStd)
{ store_tile(inv_std_block_window,
const auto inv_std_m = make_naive_tensor_view_packed<address_space_enum::global>( cast_tile<InvStdDataType>(inv_std_compute_block_tensor));
p_invStd, make_tuple(M), number<32>{});
auto inv_std_block_window =
make_tile_window(inv_std_m, make_tuple(number<kMPerBlock>{}), {iM});
store_tile(inv_std_block_window, cast_tile<MeanDataType>(inv_std_compute_block_tensor));
}
// TODO: Extract normalize pipeline
const auto y_m_n = make_naive_tensor_view<address_space_enum::global>(
p_y, make_tuple(M, N), make_tuple(N, 1), number<32>{}, number<1>{});
auto y_block_window = make_tile_window(
y_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0});
constexpr auto gammaDstr = MakeGammaBetaBlockTileDistribution();
constexpr auto betaDstr = gammaDstr;
auto gamma_block_window =
make_tile_window(gamma_n, make_tuple(number<kNPerBlock>{}), {0}, gammaDstr);
auto beta_block_window = make_tile_window(
beta_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {0}, betaDstr);
// reverse read x to reuse cache // reverse read x to reuse cache
ck_tile::index_t stride_to_right_most_window = N - kNPerBlock; ck_tile::index_t stride_to_right_most_window =
N % kNPerBlock == 0 ? N - kNPerBlock : N - N % kNPerBlock;
move_tile_window(x_block_window, {0, -kNPerBlock}); move_tile_window(x_block_window, {0, -kNPerBlock});
move_tile_window(gamma_block_window, {stride_to_right_most_window}); move_tile_window(gamma_block_window, {stride_to_right_most_window});
...@@ -274,17 +243,209 @@ struct Layernorm2dFwd ...@@ -274,17 +243,209 @@ struct Layernorm2dFwd
} }
} }
template <typename XBlockWindow,
typename GammaBlockWindow,
typename BetaBlockWindow,
typename YBlockWindow,
typename MeanBlockWindow,
typename InvStdBlockWindow,
bool Cond = (kHasGamma && kHasBeta)>
CK_TILE_DEVICE std::enable_if_t<Cond>
OnePassLayernorm2dFwd(XBlockWindow& x_block_window,
GammaBlockWindow& gamma_block_window,
BetaBlockWindow& beta_block_window,
YBlockWindow& y_block_window,
MeanBlockWindow& mean_block_window,
InvStdBlockWindow& inv_std_block_window,
ComputeDataType epsilon,
ck_tile::index_t N) const
{
int welford_max_count = GetWelfordMaxCount(N);
ThreadWelford<ComputeDataType, XDataType> thread_welford{welford_max_count};
using XTensorType = decltype(load_tile(x_block_window));
auto mean_compute_block_tensor =
thread_welford.template MakeInitialMeanVarDistributedTensor<XTensorType>();
auto var_compute_block_tensor =
thread_welford.template MakeInitialMeanVarDistributedTensor<XTensorType>();
clear_tile(mean_compute_block_tensor);
clear_tile(var_compute_block_tensor);
const auto x_block_tensor = load_tile(x_block_window);
thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor);
// TODO: support cross warp Welford
WarpMergeWelford<ComputeDataType, true>{}(
mean_compute_block_tensor, var_compute_block_tensor, thread_welford.cur_count_);
auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon);
if constexpr(kSaveMean)
store_tile(mean_block_window, cast_tile<MeanDataType>(mean_compute_block_tensor));
if constexpr(kSaveInvStd)
store_tile(inv_std_block_window,
cast_tile<InvStdDataType>(inv_std_compute_block_tensor));
// normalize
const auto gamma_block_tensor = load_tile(gamma_block_window);
const auto beta_block_tensor = load_tile(beta_block_window);
constexpr auto x_spans = decltype(x_block_tensor)::get_distributed_spans();
auto y_block_tensor =
make_static_distributed_tensor<YDataType>(x_block_tensor.get_tile_distribution());
sweep_tile_span(x_spans[I1], [&](auto idx1) {
constexpr auto j_idx = make_tuple(idx1);
const auto gamma = type_convert<ComputeDataType>(gamma_block_tensor[j_idx]);
const auto beta = type_convert<ComputeDataType>(beta_block_tensor[j_idx]);
sweep_tile_span(x_spans[I0], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto mean = mean_compute_block_tensor[i_idx];
const auto inv_std = inv_std_compute_block_tensor[i_idx];
const auto x = type_convert<ComputeDataType>(x_block_tensor[i_j_idx]);
auto y = (x - mean) * inv_std * gamma + beta;
y_block_tensor(i_j_idx) = type_convert<YDataType>(y);
});
});
store_tile(y_block_window, y_block_tensor);
}
CK_TILE_DEVICE void operator()(Kargs kargs) const CK_TILE_DEVICE void operator()(Kargs kargs) const
{ {
TwoPassLayernorm2dFwd(static_cast<const XDataType*>(kargs.p_x), const auto x_m_n = [&]() {
static_cast<const GammaDataType*>(kargs.p_gamma), const auto x_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<const BetaDataType*>(kargs.p_beta), static_cast<const XDataType*>(kargs.p_x),
static_cast<YDataType*>(kargs.p_y), make_tuple(kargs.M, kargs.N),
static_cast<MeanDataType*>(kargs.p_mean), make_tuple(kargs.N, 1),
static_cast<InvStdDataType*>(kargs.p_invStd), number<kNPerThread>{},
static_cast<const ComputeDataType>(kargs.epsilon), number<1>{});
kargs.M,
kargs.N); return pad_tensor_view(x_dram_naive,
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
sequence<kPadM, kPadN>{});
}();
const auto gamma_n = [&]() {
const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<const GammaDataType*>(kargs.p_gamma),
make_tuple(kargs.N),
make_tuple(1),
number<kNPerThread>{},
number<1>{});
return pad_tensor_view(
gamma_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<kPadN>{});
}();
const auto beta_n = [&]() {
const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<const BetaDataType*>(kargs.p_beta),
make_tuple(kargs.N),
make_tuple(1),
number<kNPerThread>{},
number<1>{});
return pad_tensor_view(
gamma_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<kPadN>{});
}();
const auto iM = get_block_id() * kMPerBlock;
constexpr auto xDstr = MakeXBlockTileDistribution();
auto x_block_window = make_tile_window(
x_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0}, xDstr);
const auto y_m_n = [&]() {
const auto y_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<YDataType*>(kargs.p_y),
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.N, 1),
number<kNPerThread>{},
number<1>{});
return pad_tensor_view(y_dram_naive,
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
sequence<kPadM, kPadN>{});
}();
auto y_block_window = make_tile_window(
y_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0});
constexpr auto gammaDstr = MakeGammaBetaBlockTileDistribution();
constexpr auto betaDstr = gammaDstr;
auto gamma_block_window =
make_tile_window(gamma_n, make_tuple(number<kNPerBlock>{}), {0}, gammaDstr);
auto beta_block_window = make_tile_window(
beta_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {0}, betaDstr);
auto mean_block_window = [&]() {
if constexpr(kSaveMean)
{
const auto mean_m = [&]() {
const auto mean_dram_naive =
make_naive_tensor_view_packed<address_space_enum::global>(
static_cast<MeanDataType*>(kargs.p_mean),
make_tuple(kargs.M),
number<1>{});
return pad_tensor_view(
mean_dram_naive, make_tuple(number<kMPerBlock>{}), sequence<kPadM>{});
}();
return make_tile_window(mean_m, make_tuple(number<kMPerBlock>{}), {iM});
}
else
return make_null_tile_window(make_tuple(number<kMPerBlock>{}));
}();
auto inv_std_block_window = [&]() {
if constexpr(kSaveInvStd)
{
const auto inv_std_m = [&]() {
const auto inv_std_dram_naive =
make_naive_tensor_view_packed<address_space_enum::global>(
static_cast<InvStdDataType*>(kargs.p_invStd),
make_tuple(kargs.M),
number<1>{});
return pad_tensor_view(
inv_std_dram_naive, make_tuple(number<kMPerBlock>{}), sequence<kPadM>{});
}();
return make_tile_window(inv_std_m, make_tuple(number<kMPerBlock>{}), {iM});
}
else
return make_null_tile_window(make_tuple(number<kMPerBlock>{}));
}();
if(kargs.N <= kNPerBlock)
OnePassLayernorm2dFwd(x_block_window,
gamma_block_window,
beta_block_window,
y_block_window,
mean_block_window,
inv_std_block_window,
static_cast<const ComputeDataType>(kargs.epsilon),
kargs.N);
else
TwoPassLayernorm2dFwd(x_block_window,
gamma_block_window,
beta_block_window,
y_block_window,
mean_block_window,
inv_std_block_window,
static_cast<const ComputeDataType>(kargs.epsilon),
kargs.N);
} }
}; };
......
...@@ -14,17 +14,21 @@ template <typename XDataType_, ...@@ -14,17 +14,21 @@ template <typename XDataType_,
typename YDataType_, typename YDataType_,
typename MeanDataType_, typename MeanDataType_,
typename InvStdDataType_, typename InvStdDataType_,
typename BlockShape_> typename BlockShape_,
bool kPadM_,
bool kPadN_>
struct BlockLayernorm2dFwdProblem struct BlockLayernorm2dFwdProblem
{ {
using XDataType = remove_cvref_t<XDataType_>; using XDataType = remove_cvref_t<XDataType_>;
using GammaDataType = remove_cvref_t<GammaDataType_>; using GammaDataType = remove_cvref_t<GammaDataType_>;
using BetaDataType = remove_cvref_t<BetaDataType_>; using BetaDataType = remove_cvref_t<BetaDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>; using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using YDataType = remove_cvref_t<YDataType_>; using YDataType = remove_cvref_t<YDataType_>;
using MeanDataType = remove_cvref_t<MeanDataType_>; using MeanDataType = remove_cvref_t<MeanDataType_>;
using InvStdDataType = remove_cvref_t<InvStdDataType_>; using InvStdDataType = remove_cvref_t<InvStdDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>; using BlockShape = remove_cvref_t<BlockShape_>;
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
}; };
} // namespace ck_tile } // namespace ck_tile
...@@ -37,11 +37,7 @@ function(add_instance_library INSTANCE_NAME) ...@@ -37,11 +37,7 @@ function(add_instance_library INSTANCE_NAME)
endforeach() endforeach()
endif() endif()
if(INSTANCES_ONLY) set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
set(INST_TARGETS ${DEFAULT_GPU_TARGETS})
else()
set(INST_TARGETS ${GPU_TARGETS})
endif()
# Do not build DL instances if DL_KERNELS macro is not set # Do not build DL instances if DL_KERNELS macro is not set
foreach(source IN LISTS ARGN) foreach(source IN LISTS ARGN)
...@@ -64,9 +60,9 @@ function(add_instance_library INSTANCE_NAME) ...@@ -64,9 +60,9 @@ function(add_instance_library INSTANCE_NAME)
list(REMOVE_ITEM ARGN "${source}") list(REMOVE_ITEM ARGN "${source}")
endif() endif()
endforeach() endforeach()
# Do not build mha instances if gfx94 targets are not on the target list # Do not build mha instances if gfx94 or gfx90a targets are not on the target list
foreach(source IN LISTS ARGN) foreach(source IN LISTS ARGN)
if(NOT INST_TARGETS MATCHES "gfx94" AND source MATCHES "mha") if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx90a" AND source MATCHES "mha")
message("removing mha instance ${source} ") message("removing mha instance ${source} ")
list(REMOVE_ITEM ARGN "${source}") list(REMOVE_ITEM ARGN "${source}")
endif() endif()
...@@ -75,17 +71,13 @@ function(add_instance_library INSTANCE_NAME) ...@@ -75,17 +71,13 @@ function(add_instance_library INSTANCE_NAME)
if(ARGN) if(ARGN)
set(INST_OBJ) set(INST_OBJ)
foreach(source IN LISTS ARGN) foreach(source IN LISTS ARGN)
if(INSTANCES_ONLY) set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
set(INST_TARGETS ${DEFAULT_GPU_TARGETS})
else()
set(INST_TARGETS ${GPU_TARGETS})
endif()
if(source MATCHES "_xdl") if(source MATCHES "_xdl")
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201) list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
elseif(ARGN MATCHES "_wmma") elseif(ARGN MATCHES "_wmma")
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
elseif(ARGN MATCHES "mha") elseif(ARGN MATCHES "mha")
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201) list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
endif() endif()
set(offload_targets) set(offload_targets)
foreach(target IN LISTS INST_TARGETS) foreach(target IN LISTS INST_TARGETS)
...@@ -191,12 +183,7 @@ FOREACH(subdir_path ${dir_list}) ...@@ -191,12 +183,7 @@ FOREACH(subdir_path ${dir_list})
set(add_inst 1) set(add_inst 1)
endif() endif()
if(INSTANCES_ONLY) set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
set(INST_TARGETS ${DEFAULT_GPU_TARGETS})
else()
set(INST_TARGETS ${GPU_TARGETS})
endif()
if(("${cmake_instance}" MATCHES "quantization") AND (DEFINED DTYPES) AND (NOT DTYPES MATCHES "int8")) if(("${cmake_instance}" MATCHES "quantization") AND (DEFINED DTYPES) AND (NOT DTYPES MATCHES "int8"))
message("quantization instances will not be built!") message("quantization instances will not be built!")
...@@ -320,8 +307,7 @@ if(CK_DEVICE_CONV_INSTANCES) ...@@ -320,8 +307,7 @@ if(CK_DEVICE_CONV_INSTANCES)
endif() endif()
if(CK_DEVICE_MHA_INSTANCES) if(CK_DEVICE_MHA_INSTANCES)
set(gpu_list ${INST_TARGETS}) set(gpu_list ${INST_TARGETS})
list(FILTER gpu_list INCLUDE REGEX "^gfx94") if(gpu_list MATCHES "gfx94" OR gpu_list MATCHES "gfx90a")
if(gpu_list)
add_library(device_mha_operations STATIC ${CK_DEVICE_MHA_INSTANCES}) add_library(device_mha_operations STATIC ${CK_DEVICE_MHA_INSTANCES})
add_library(composablekernels::device_mha_operations ALIAS device_mha_operations) add_library(composablekernels::device_mha_operations ALIAS device_mha_operations)
target_compile_features(device_mha_operations PUBLIC) target_compile_features(device_mha_operations PUBLIC)
......
...@@ -24,7 +24,7 @@ set(PROFILER_SOURCES ...@@ -24,7 +24,7 @@ set(PROFILER_SOURCES
profile_permute_scale.cpp profile_permute_scale.cpp
) )
if(GPU_TARGETS MATCHES "gfx9") if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp) list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp)
list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp) list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp)
...@@ -49,7 +49,7 @@ if(GPU_TARGETS MATCHES "gfx9") ...@@ -49,7 +49,7 @@ if(GPU_TARGETS MATCHES "gfx9")
list(APPEND PROFILER_SOURCES profile_grouped_gemm_multiply_tile_loop.cpp) list(APPEND PROFILER_SOURCES profile_grouped_gemm_multiply_tile_loop.cpp)
endif() endif()
list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp) list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp)
if(GPU_TARGETS MATCHES "gfx94") if(SUPPORTED_GPU_TARGETS MATCHES "gfx94")
list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply.cpp) list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_ab_scale.cpp) list(APPEND PROFILER_SOURCES profile_gemm_ab_scale.cpp)
endif() endif()
...@@ -69,7 +69,7 @@ if(GPU_TARGETS MATCHES "gfx9") ...@@ -69,7 +69,7 @@ if(GPU_TARGETS MATCHES "gfx9")
endif() endif()
if(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12" OR GPU_TARGETS MATCHES "gfx9") if(SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12" OR SUPPORTED_GPU_TARGETS MATCHES "gfx9")
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp) list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp)
endif() endif()
...@@ -111,7 +111,7 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_inst ...@@ -111,7 +111,7 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_inst
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_transpose_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_transpose_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_permute_scale_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_permute_scale_instance)
if(GPU_TARGETS MATCHES "gfx9") if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance)
...@@ -135,7 +135,7 @@ if(GPU_TARGETS MATCHES "gfx9") ...@@ -135,7 +135,7 @@ if(GPU_TARGETS MATCHES "gfx9")
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance)
if(GPU_TARGETS MATCHES "gfx94") if(SUPPORTED_GPU_TARGETS MATCHES "gfx94")
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_ab_scale_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_ab_scale_instance)
endif() endif()
...@@ -159,7 +159,7 @@ if(GPU_TARGETS MATCHES "gfx9") ...@@ -159,7 +159,7 @@ if(GPU_TARGETS MATCHES "gfx9")
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_convinvscale_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_convinvscale_instance)
endif() endif()
if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12") if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12")
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance)
endif() endif()
......
...@@ -41,11 +41,7 @@ function(add_test_executable TEST_NAME) ...@@ -41,11 +41,7 @@ function(add_test_executable TEST_NAME)
endforeach() endforeach()
endif() endif()
if(INSTANCES_ONLY) set(TEST_TARGETS ${SUPPORTED_GPU_TARGETS})
set(TEST_TARGETS ${DEFAULT_GPU_TARGETS})
else()
set(TEST_TARGETS ${GPU_TARGETS})
endif()
foreach(source IN LISTS ARGN) foreach(source IN LISTS ARGN)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
...@@ -122,11 +118,7 @@ function(add_gtest_executable TEST_NAME) ...@@ -122,11 +118,7 @@ function(add_gtest_executable TEST_NAME)
endforeach() endforeach()
endif() endif()
if(INSTANCES_ONLY) set(TEST_TARGETS ${SUPPORTED_GPU_TARGETS})
set(TEST_TARGETS ${DEFAULT_GPU_TARGETS})
else()
set(TEST_TARGETS ${GPU_TARGETS})
endif()
foreach(source IN LISTS ARGN) foreach(source IN LISTS ARGN)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
...@@ -211,10 +203,10 @@ add_subdirectory(conv_tensor_rearrange) ...@@ -211,10 +203,10 @@ add_subdirectory(conv_tensor_rearrange)
add_subdirectory(transpose) add_subdirectory(transpose)
add_subdirectory(permute_scale) add_subdirectory(permute_scale)
add_subdirectory(wrapper) add_subdirectory(wrapper)
if(GPU_TARGETS MATCHES "gfx11") if(SUPPORTED_GPU_TARGETS MATCHES "gfx11")
add_subdirectory(wmma_op) add_subdirectory(wmma_op)
endif() endif()
if(GPU_TARGETS MATCHES "gfx942" AND CK_HIP_VERSION_MAJOR GREATER_EQUAL 6 AND CK_HIP_VERSION_MINOR GREATER_EQUAL 2) # smfmac needs ROCm6.2 if(SUPPORTED_GPU_TARGETS MATCHES "gfx942" AND CK_HIP_VERSION_MAJOR GREATER_EQUAL 6 AND CK_HIP_VERSION_MINOR GREATER_EQUAL 2) # smfmac needs ROCm6.2
add_subdirectory(smfmac_op) add_subdirectory(smfmac_op)
endif() endif()
add_subdirectory(position_embedding) add_subdirectory(position_embedding)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment