"...composable_kernel.git" did not exist on "aafc3ac27a4d448b728a241fd6072005b87df22f"
Commit da2bce29 authored by Anthony Chang
Browse files

revert accidental example code changes

parent 16428e7f
...@@ -8,7 +8,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g ...@@ -8,7 +8,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
|-------------------------------------| |-------------------------------------|
Gemm1 Gemm1
*/ */
#pragma clang diagnostic ignored "-Wunused-variable"
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
#include <initializer_list> #include <initializer_list>
...@@ -57,7 +57,7 @@ using Acc0ElementOp = ck::tensor_operation::element_wise::Scale; ...@@ -57,7 +57,7 @@ using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp = PassThrough; using B1ElementOp = PassThrough;
using CElementOp = PassThrough; using CElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
static constexpr auto MaskingSpec = static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
...@@ -73,65 +73,65 @@ using DeviceGemmInstance = ...@@ -73,65 +73,65 @@ using DeviceGemmInstance =
NumDimN, NumDimN,
NumDimK, NumDimK,
NumDimO, NumDimO,
ck::half_t, ADataType,
ck::half_t, B0DataType,
ck::half_t, B1DataType,
ck::half_t, CDataType,
ck::Tuple<>, Acc0BiasDataType,
ck::Tuple<>, Acc1BiasDataType,
float, AccDataType,
float, // CShuffleDType, CShuffleDataType,
ck::tensor_operation::element_wise::PassThrough, AElementOp,
ck::tensor_operation::element_wise::PassThrough, B0ElementOp,
ck::tensor_operation::element_wise::Scale, Acc0ElementOp,
ck::tensor_operation::element_wise::PassThrough, B1ElementOp,
ck::tensor_operation::element_wise::PassThrough, CElementOp,
GemmSpec, GemmSpec,
ck::tensor_operation::device::TensorSpecialization::Default, TensorSpecA,
ck::tensor_operation::device::TensorSpecialization::Default, TensorSpecB0,
ck::tensor_operation::device::TensorSpecialization::Default, TensorSpecB1,
ck::tensor_operation::device::TensorSpecialization::Default, TensorSpecC,
1,
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
64, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
2, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1, 1,
256, // block_size 4,
64, // m_per_block 2,
256, // n_per_block false,
32, // k_per_block 1, // CShuffleMXdlPerWavePerShuffle
64, // Gemm1NPerBlock 2, // CShuffleNXdlPerWavePerShuffle
32, // Gemm1KPerBlock S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // ak1 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
8, // bk1 MaskingSpec>; // MaskingSpecialization
2, // b1k1
16, // m_per_xdl
16, // n_per_xdl
1, // m_xdl_per_wave
16, // n_xdl_per_wave
4, // Gemm1NXdlPerWave
ck::Sequence<4, 64, 1>, // thread_cluster_length
ck::Sequence<1, 0, 2>, // thread_cluster_arrange_order
ck::Sequence<1, 0, 2>, // src_access_order
2, // src_vector_dim
8, // src_scalar_per_vector
8, // dst_scalar_per_vector
1, // add_extra_dim
ck::Sequence<4, 64, 1>, // thread_cluster_length
ck::Sequence<1, 0, 2>, // thread_cluster_arrange_order
ck::Sequence<1, 0, 2>, // src_access_order
2, // src_vector_dim
8, // src_scalar_per_vector
8, // dst_scalar_per_vector
1, // add_extra_dim
ck::Sequence<16, 16, 1>, // thread_cluster_length
ck::Sequence<0, 2, 1>, // thread_cluster_arrange_order
ck::Sequence<0, 2, 1>, // src_access_order
1, // src_vector_dim
4, // src_scalar_per_vector
2, // dst_scalar_per_vector
0, // add_extra_dim
1, // m_xdl_per_wave
4, // n_xdl_per_wave
ck::Sequence<1, 32, 1, 8>, // m_n_block_wave_per_xdl
8, // scalar_per_vector
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled>; // causal_mask
// Ref Gemm0: fp16 in, fp32 out // Ref Gemm0: fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType, using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment