Commit 6fa4feac authored by letaoqin
Browse files

add dim=32 64

parent 9e49c2bf
......@@ -9,7 +9,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
Gemm1
*/
-#define DIM 128 // DIM should be a multiple of 8.
+#define DIM 64 // DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
......@@ -80,7 +80,7 @@ static constexpr bool Deterministic = false;
#if(DIM <= 32)
using DeviceGemmInstance =
-ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
+ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2<
NumDimG,
NumDimM,
NumDimN,
......@@ -151,7 +151,7 @@ using DeviceGemmInstance =
Deterministic>;
#elif(DIM <= 64)
using DeviceGemmInstance =
-ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
+ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2<
NumDimG,
NumDimM,
NumDimN,
......
......@@ -23,7 +23,7 @@ int run(int argc, char* argv[])
bool input_permute = false;
bool output_permute = true;
-float p_drop = 0;
+float p_drop = 0.1;
const unsigned long long seed = 1;
const unsigned long long offset = 0;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment