"tests/pipelines/vscode:/vscode.git/clone" did not exist on "64cbd8e27ac9c74223ed67b141491cf902fa4836"
Commit 5ba7dc40 authored by danyao12's avatar danyao12
Browse files

add grouped pt1 qloop

parent 30716b7c
......@@ -8,6 +8,7 @@ add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_pe
add_example_executable(example_grouped_multihead_attention_forward grouped_multihead_attention_forward.cpp)
add_example_executable(example_batched_multihead_attention_forward batched_multihead_attention_forward.cpp)
add_example_executable(example_grouped_multihead_attention_backward grouped_multihead_attention_backward.cpp)
add_example_executable(example_grouped_multihead_attention_backward_v2 grouped_multihead_attention_backward_v2.cpp)
add_example_executable(example_batched_multihead_attention_backward batched_multihead_attention_backward.cpp)
add_example_executable(example_batched_multihead_attention_backward_v2 batched_multihead_attention_backward_v2.cpp)
add_example_executable(example_batched_multihead_attention_backward_v3 batched_multihead_attention_backward_v3.cpp)
......
......@@ -63,9 +63,9 @@ using Scale = ck::tensor_operation::element_wise::Scale;
using QKVElementOp = PassThrough;
using YElementOp = PassThrough;
using InputDataType = BF16;
using OutputDataType = F32;
using GemmDataType = BF16;
using InputDataType = F16;
using OutputDataType = F16;
using GemmDataType = F16;
using AccDataType = F32;
using ShuffleDataType = F32;
using LSEDataType = F32;
......@@ -80,7 +80,7 @@ static constexpr ck::index_t NumDimK = 1;
static constexpr ck::index_t NumDimO = 1;
// When OutputDataType == F32, CShuffleBlockTransferScalarPerVector_NPerBlock = 4
// When OutputDataType == F16/BF16, CShuffleBlockTransferScalarPerVector_NPerBlock = 8
static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 4;
static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
#if USING_MASK
......@@ -95,7 +95,7 @@ static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecia
static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr bool Deterministic = true;
static constexpr bool Deterministic = false;
// DIM should be a multiple of 8.
// If DIM <= 32 , ues prototype1 1st template.
......@@ -511,7 +511,7 @@ int run(int argc, char* argv[])
bool input_permute = false;
bool output_permute = false;
float p_drop = 0.2;
float p_drop = 0.0;
const unsigned long long seed = 1;
const unsigned long long offset = 0;
......
......@@ -62,9 +62,9 @@ using Scale = ck::tensor_operation::element_wise::Scale;
using QKVElementOp = PassThrough;
using YElementOp = PassThrough;
using InputDataType = BF16;
using OutputDataType = F32;
using GemmDataType = BF16;
using InputDataType = F16;
using OutputDataType = F16;
using GemmDataType = F16;
using AccDataType = F32;
using ShuffleDataType = F32;
using LSEDataType = F32;
......@@ -79,7 +79,7 @@ static constexpr ck::index_t NumDimK = 1;
static constexpr ck::index_t NumDimO = 1;
// When OutputDataType == F32, CShuffleBlockTransferScalarPerVector_NPerBlock = 4
// When OutputDataType == F16/BF16, CShuffleBlockTransferScalarPerVector_NPerBlock = 8
static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 4;
static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
#if USING_MASK
......@@ -94,7 +94,7 @@ static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecia
static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr bool Deterministic = true;
static constexpr bool Deterministic = false;
// DIM should be a multiple of 8.
// If DIM <= 32 , ues prototype1 1st template.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment