Commit 3480e42b authored by guangzlu's avatar guangzlu
Browse files

fixed bugs after merge

parent e439b369
......@@ -105,7 +105,7 @@ static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecia
static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr bool Deterministic = true;
static constexpr bool Deterministic = false;
// DIM should be a multiple of 8.
// If DIM <= 32, use prototype1 1st template.
......@@ -190,7 +190,8 @@ using DeviceGemmInstanceBWD =
NumDimN,
NumDimK,
NumDimO,
DataType,
InputDataType,
OutputDataType,
GemmDataType,
ZDataType,
LSEDataType,
......@@ -249,7 +250,8 @@ using DeviceGemmInstanceBWD =
1, // CShuffleNXdlPerWavePerShuffle
S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
MaskingSpec,
Deterministic>; // Deterministic
#elif(DIM <= 64)
using DeviceGemmInstanceFWD =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
......@@ -329,7 +331,8 @@ using DeviceGemmInstanceBWD =
NumDimN,
NumDimK,
NumDimO,
DataType,
InputDataType,
OutputDataType,
GemmDataType,
ZDataType,
LSEDataType,
......@@ -388,7 +391,8 @@ using DeviceGemmInstanceBWD =
1, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
MaskingSpec,
Deterministic>; // Deterministic
// using DeviceGemmInstanceBWD =
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
......
......@@ -146,6 +146,9 @@ __global__ void
c0_matrix_mask,
p_drop,
ph,
g_idx,
MRaw,
NRaw,
i);
}
}
......@@ -178,6 +181,9 @@ __global__ void
c0_matrix_mask,
p_drop,
ph,
g_idx,
MRaw,
NRaw,
0);
}
......@@ -1007,8 +1013,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
arg.c0_matrix_mask_,
arg.p_drop_,
arg.seed_,
arg.offset_);
arg.offset_,
arg.raw_lengths_mz_nz_kz_gemm1nz_[0],
arg.raw_lengths_mz_nz_kz_gemm1nz_[1]);
};
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment