Commit 16428e7f authored by Anthony Chang

more test

parent 4003fefb
@@ -8,7 +8,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
 |-------------------------------------|
                  Gemm1
 */
+#pragma clang diagnostic ignored "-Wunused-variable"
 #include <iostream>
 #include <numeric>
 #include <initializer_list>
@@ -57,7 +57,7 @@ using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
 using B1ElementOp = PassThrough;
 using CElementOp = PassThrough;
 
-static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
+static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
 
 static constexpr auto MaskingSpec =
     ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
@@ -73,65 +73,65 @@ using DeviceGemmInstance =
         NumDimN,
         NumDimK,
         NumDimO,
-        ADataType,
-        B0DataType,
-        B1DataType,
-        CDataType,
-        Acc0BiasDataType,
-        Acc1BiasDataType,
-        AccDataType,
-        CShuffleDataType,
-        AElementOp,
-        B0ElementOp,
-        Acc0ElementOp,
-        B1ElementOp,
-        CElementOp,
+        ck::half_t,
+        ck::half_t,
+        ck::half_t,
+        ck::half_t,
+        ck::Tuple<>,
+        ck::Tuple<>,
+        float,
+        float, // CShuffleDType,
+        ck::tensor_operation::element_wise::PassThrough,
+        ck::tensor_operation::element_wise::PassThrough,
+        ck::tensor_operation::element_wise::Scale,
+        ck::tensor_operation::element_wise::PassThrough,
+        ck::tensor_operation::element_wise::PassThrough,
         GemmSpec,
-        TensorSpecA,
-        TensorSpecB0,
-        TensorSpecB1,
-        TensorSpecC,
+        ck::tensor_operation::device::TensorSpecialization::Default,
+        ck::tensor_operation::device::TensorSpecialization::Default,
+        ck::tensor_operation::device::TensorSpecialization::Default,
+        ck::tensor_operation::device::TensorSpecialization::Default,
         1,
-        256,
-        128, // MPerBlock
-        128, // NPerBlock
-        32,  // KPerBlock
+        256, // block_size
+        64,  // m_per_block
+        256, // n_per_block
+        32,  // k_per_block
         64,  // Gemm1NPerBlock
         32,  // Gemm1KPerBlock
-        8,   // AK1
-        8,   // BK1
-        2,   // B1K1
-        32,  // MPerXDL
-        32,  // NPerXDL
-        1,   // MXdlPerWave
-        4,   // NXdlPerWave
-        2,   // Gemm1NXdlPerWave
-        S<4, 64, 1>, // ABlockTransfer
-        S<1, 0, 2>,
-        S<1, 0, 2>,
-        2,
-        8,
-        8,
-        true,
-        S<4, 64, 1>, // BBlockTransfer
-        S<1, 0, 2>,
-        S<1, 0, 2>,
-        2,
-        8,
-        8,
-        true,
-        S<16, 16, 1>, // B1BlockTransfer
-        S<0, 2, 1>,
-        S<0, 2, 1>,
-        1,
-        4,
-        2,
-        false,
-        1, // CShuffleMXdlPerWavePerShuffle
-        2, // CShuffleNXdlPerWavePerShuffle
-        S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
-        8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
-        MaskingSpec>;   // MaskingSpecialization
+        8,   // ak1
+        8,   // bk1
+        2,   // b1k1
+        16,  // m_per_xdl
+        16,  // n_per_xdl
+        1,   // m_xdl_per_wave
+        16,  // n_xdl_per_wave
+        4,   // Gemm1NXdlPerWave
+        ck::Sequence<4, 64, 1>, // thread_cluster_length
+        ck::Sequence<1, 0, 2>,  // thread_cluster_arrange_order
+        ck::Sequence<1, 0, 2>,  // src_access_order
+        2,                      // src_vector_dim
+        8,                      // src_scalar_per_vector
+        8,                      // dst_scalar_per_vector
+        1,                      // add_extra_dim
+        ck::Sequence<4, 64, 1>, // thread_cluster_length
+        ck::Sequence<1, 0, 2>,  // thread_cluster_arrange_order
+        ck::Sequence<1, 0, 2>,  // src_access_order
+        2,                      // src_vector_dim
+        8,                      // src_scalar_per_vector
+        8,                      // dst_scalar_per_vector
+        1,                      // add_extra_dim
+        ck::Sequence<16, 16, 1>, // thread_cluster_length
+        ck::Sequence<0, 2, 1>,   // thread_cluster_arrange_order
+        ck::Sequence<0, 2, 1>,   // src_access_order
+        1,                       // src_vector_dim
+        4,                       // src_scalar_per_vector
+        2,                       // dst_scalar_per_vector
+        0,                       // add_extra_dim
+        1, // m_xdl_per_wave
+        4, // n_xdl_per_wave
+        ck::Sequence<1, 32, 1, 8>, // m_n_block_wave_per_xdl
+        8,                         // scalar_per_vector
+        ck::tensor_operation::device::MaskingSpecialization::MaskDisabled>; // causal_mask
 // Ref Gemm0: fp16 in, fp32 out
 using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
...
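For orientation, the device op instantiated above computes, per batch g, C_g_m_o = Softmax(alpha * A_g_m_k * B0_g_k_n) * B1_g_n_o, where alpha comes from the Scale element-wise op applied to the Gemm0 accumulator. The snippet below is a minimal standalone host sketch of that math for a single batch in row-major fp32; it is illustrative only and is not the ReferenceBatchedGemm-based verification path used in the file.

#include <algorithm>
#include <cmath>
#include <vector>

// Naive reference: C[m][o] = sum_n Softmax_row(alpha * A*B0)[m][n] * B1[n][o].
// All matrices are densely packed, row-major, single batch, fp32 for clarity.
void reference_gemm_softmax_gemm(const std::vector<float>& a,  // M x K
                                 const std::vector<float>& b0, // K x N
                                 const std::vector<float>& b1, // N x O
                                 std::vector<float>& c,        // M x O
                                 int M, int N, int K, int O, float alpha)
{
    std::vector<float> s(N); // one row of Softmax(alpha * A*B0)
    for(int m = 0; m < M; ++m)
    {
        // Gemm0 row, scaled by alpha
        float row_max = -INFINITY;
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < K; ++k)
                acc += a[m * K + k] * b0[k * N + n];
            s[n]    = alpha * acc;
            row_max = std::max(row_max, s[n]);
        }
        // numerically stable row-wise softmax
        float sum = 0.f;
        for(int n = 0; n < N; ++n)
        {
            s[n] = std::exp(s[n] - row_max);
            sum += s[n];
        }
        for(int n = 0; n < N; ++n)
            s[n] /= sum;
        // Gemm1 row
        for(int o = 0; o < O; ++o)
        {
            float acc = 0.f;
            for(int n = 0; n < N; ++n)
                acc += s[n] * b1[n * O + o];
            c[m * O + o] = acc;
        }
    }
}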
@@ -49,7 +49,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
                                             int BatchStrideB0 = -1,
                                             int BatchStrideB1 = -1,
                                             int BatchStrideC = -1,
-                                            float alpha = 1.f)
+                                            float alpha = -1.f)
 {
@@ -187,6 +187,10 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
     b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
     b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
 
+    if(alpha < 0)
+    {
+        alpha = 1.f / std::sqrt(K); // usually 1 / sqrt(head_dim)
+    }
+
     auto a_element_op = AElementOp{};
     auto b0_element_op = B0ElementOp{};
     auto acc0_element_op = Acc0ElementOp{alpha};
...
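The second file changes the default alpha from 1.f to a negative sentinel: any negative value is now resolved inside the profiler to 1/sqrt(K), the usual scaled-dot-product-attention factor (K plays the role of head_dim here). A small sketch of that resolution logic in isolation, where resolve_alpha is a hypothetical helper rather than a CK function:

#include <cmath>
#include <cstdio>

// Illustrative only: mirrors the "alpha < 0 means use 1/sqrt(K)" convention from the diff.
float resolve_alpha(float alpha, int K)
{
    return alpha < 0 ? 1.f / std::sqrt(static_cast<float>(K)) : alpha;
}

int main()
{
    std::printf("%f\n", resolve_alpha(-1.f, 64));  // 0.125   (1/sqrt(64))
    std::printf("%f\n", resolve_alpha(-1.f, 128)); // ~0.0884 (1/sqrt(128))
    std::printf("%f\n", resolve_alpha(0.5f, 64));  // 0.5     (explicit value wins)
    return 0;
}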