Commit 94177eb6 authored by danyao12

Merge branch 'attn-train-develop-qloop' into attn-train-develop-qloop-dropout-v2

parents 44f4498a 71e2a917
@@ -98,9 +98,9 @@ static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecia
 static constexpr bool Deterministic = false;
 // DIM should be a multiple of 8.
-// If DIM <= 32 , ues prototype1 1st template.
-// If 32 < DIM <= 64 , ues prototype1 2nd template.
-// If 64 < DIM <= 128, ues prototype2 2nd template.
+// If DIM <= 32 , ues prototype1.
+// If 32 < DIM <= 64 , ues prototype1.
+// If 64 < DIM <= 128, ues prototype2.
 #if(DIM <= 32)
 // clang-format off
 using DeviceGemmInstance =
...
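For orientation, here is a minimal sketch of the head-dimension dispatch these comments describe. DIM, prototype1, and prototype2 come from the comments above; Prototype1Instance and Prototype2Instance are hypothetical stand-ins for the concrete DeviceGemm template instantiations, which are elided in this view.

    // Hypothetical sketch, not the commit's actual code: pick the device GEMM
    // instance from the compile-time head dimension DIM, per the comments above.
    #if (DIM <= 32)
    using DeviceGemmInstance = Prototype1Instance; // prototype1 covers DIM <= 32
    #elif (DIM <= 64)
    using DeviceGemmInstance = Prototype1Instance; // prototype1 also covers 32 < DIM <= 64
    #elif (DIM <= 128)
    using DeviceGemmInstance = Prototype2Instance; // prototype2 covers 64 < DIM <= 128
    #else
    #error "DIM must be a multiple of 8 and no larger than 128"
    #endif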
@@ -62,9 +62,9 @@ using Scale = ck::tensor_operation::element_wise::Scale;
 using QKVElementOp = PassThrough;
 using YElementOp = PassThrough;
-using InputDataType = BF16;
-using OutputDataType = F32;
-using GemmDataType = BF16;
+using InputDataType = F16;
+using OutputDataType = F16;
+using GemmDataType = F16;
 using AccDataType = F32;
 using ShuffleDataType = F32;
 using LSEDataType = F32;
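The type-alias change above switches the example from BF16 inputs with F32 outputs to an all-F16 configuration, while the accumulation, shuffle, and LSE types stay in F32. For reference, a sketch of the aliases these short names conventionally map to in Composable Kernel examples; the include path and aliases are assumptions, not shown in this diff:

    // Assumed, for illustration only: the short type names used above.
    #include <ck/utility/data_type.hpp>

    using F16  = ck::half_t;   // 16-bit half precision
    using BF16 = ck::bhalf_t;  // 16-bit bfloat16
    using F32  = float;        // 32-bit single precision (accumulation/LSE)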
@@ -79,7 +79,7 @@ static constexpr ck::index_t NumDimK = 1;
 static constexpr ck::index_t NumDimO = 1;
 // When OutputDataType == F32, CShuffleBlockTransferScalarPerVector_NPerBlock = 4
 // When OutputDataType == F16/BF16, CShuffleBlockTransferScalarPerVector_NPerBlock = 8
-static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 4;
+static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
 static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
 #if USING_MASK
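The comments above encode a 16-byte (128-bit) vectorized store: 4 elements when the output type is 4-byte F32, 8 elements when it is 2-byte F16/BF16. A sketch of expressing that rule generically instead of hard-coding the constant (illustrative only, not what the commit does):

    // Illustrative alternative: derive the per-thread store width from the
    // output element size so it always corresponds to 16-byte vectors.
    static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock =
        16 / sizeof(OutputDataType); // 4 for F32, 8 for F16/BF16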
@@ -97,9 +97,9 @@ static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecia
 static constexpr bool Deterministic = false;
 // DIM should be a multiple of 8.
-// If DIM <= 32 , ues prototype1 1st template.
-// If 32 < DIM <= 64 , ues prototype1 2nd template.
-// If 64 < DIM <= 128, ues prototype2 2nd template.
+// If DIM <= 32 , ues prototype1.
+// If 32 < DIM <= 64 , ues prototype1.
+// If 64 < DIM <= 128, ues prototype2.
 #if(DIM <= 32)
 // clang-format off
 using DeviceGemmInstance =
...
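Both changed files repeat the requirement that DIM be a multiple of 8 and no larger than 128. A small compile-time guard could make that explicit; this is an illustrative sketch, not part of the commit:

    // Illustrative only: enforce the documented constraints on DIM at compile time.
    static_assert(DIM % 8 == 0, "DIM (head dimension) must be a multiple of 8");
    static_assert(DIM <= 128, "DIM larger than 128 is not covered by these kernel templates");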