Commit 6fa4feac authored by letaoqin's avatar letaoqin
Browse files

add dim=32 64

parent 9e49c2bf
...@@ -9,7 +9,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g ...@@ -9,7 +9,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
Gemm1 Gemm1
*/ */
#define DIM 128 // DIM should be a multiple of 8. #define DIM 64 // DIM should be a multiple of 8.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
...@@ -80,7 +80,7 @@ static constexpr bool Deterministic = false; ...@@ -80,7 +80,7 @@ static constexpr bool Deterministic = false;
#if(DIM <= 32) #if(DIM <= 32)
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -151,7 +151,7 @@ using DeviceGemmInstance = ...@@ -151,7 +151,7 @@ using DeviceGemmInstance =
Deterministic>; Deterministic>;
#elif(DIM <= 64) #elif(DIM <= 64)
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
......
...@@ -23,7 +23,7 @@ int run(int argc, char* argv[]) ...@@ -23,7 +23,7 @@ int run(int argc, char* argv[])
bool input_permute = false; bool input_permute = false;
bool output_permute = true; bool output_permute = true;
float p_drop = 0; float p_drop = 0.1;
const unsigned long long seed = 1; const unsigned long long seed = 1;
const unsigned long long offset = 0; const unsigned long long offset = 0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment