Commit da2bce29 authored by Anthony Chang's avatar Anthony Chang
Browse files

revert accidental example code changes

parent 16428e7f
......@@ -8,7 +8,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
|-------------------------------------|
Gemm1
*/
#pragma clang diagnostic ignored "-Wunused-variable"
#include <iostream>
#include <numeric>
#include <initializer_list>
......@@ -57,7 +57,7 @@ using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
......@@ -73,65 +73,65 @@ using DeviceGemmInstance =
NumDimN,
NumDimK,
NumDimO,
ck::half_t,
ck::half_t,
ck::half_t,
ck::half_t,
ck::Tuple<>,
ck::Tuple<>,
float,
float, // CShuffleDType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Scale,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ADataType,
B0DataType,
B1DataType,
CDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AccDataType,
CShuffleDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
GemmSpec,
ck::tensor_operation::device::TensorSpecialization::Default,
ck::tensor_operation::device::TensorSpecialization::Default,
ck::tensor_operation::device::TensorSpecialization::Default,
ck::tensor_operation::device::TensorSpecialization::Default,
TensorSpecA,
TensorSpecB0,
TensorSpecB1,
TensorSpecC,
1,
256, // block_size
64, // m_per_block
256, // n_per_block
32, // k_per_block
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
64, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // ak1
8, // bk1
2, // b1k1
16, // m_per_xdl
16, // n_per_xdl
1, // m_xdl_per_wave
16, // n_xdl_per_wave
4, // Gemm1NXdlPerWave
ck::Sequence<4, 64, 1>, // thread_cluster_length
ck::Sequence<1, 0, 2>, // thread_cluster_arrange_order
ck::Sequence<1, 0, 2>, // src_access_order
2, // src_vector_dim
8, // src_scalar_per_vector
8, // dst_scalar_per_vector
1, // add_extra_dim
ck::Sequence<4, 64, 1>, // thread_cluster_length
ck::Sequence<1, 0, 2>, // thread_cluster_arrange_order
ck::Sequence<1, 0, 2>, // src_access_order
2, // src_vector_dim
8, // src_scalar_per_vector
8, // dst_scalar_per_vector
1, // add_extra_dim
ck::Sequence<16, 16, 1>, // thread_cluster_length
ck::Sequence<0, 2, 1>, // thread_cluster_arrange_order
ck::Sequence<0, 2, 1>, // src_access_order
1, // src_vector_dim
4, // src_scalar_per_vector
2, // dst_scalar_per_vector
0, // add_extra_dim
1, // m_xdl_per_wave
4, // n_xdl_per_wave
ck::Sequence<1, 32, 1, 8>, // m_n_block_wave_per_xdl
8, // scalar_per_vector
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled>; // causal_mask
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
2, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
// Ref Gemm0: fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment