"...composable_kernel_rocm.git" did not exist on "561ec12f4abf7ae72cecf3761c7b6ac2e58a5ed3"
Commit 3480e42b authored by guangzlu's avatar guangzlu
Browse files

fixed bugs after merge

parent e439b369
...@@ -105,7 +105,7 @@ static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecia ...@@ -105,7 +105,7 @@ static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecia
static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr bool Deterministic = true; static constexpr bool Deterministic = false;
// DIM should be a multiple of 8. // DIM should be a multiple of 8.
// If DIM <= 32 , use prototype1 1st template. // If DIM <= 32 , use prototype1 1st template.
...@@ -190,7 +190,8 @@ using DeviceGemmInstanceBWD = ...@@ -190,7 +190,8 @@ using DeviceGemmInstanceBWD =
NumDimN, NumDimN,
NumDimK, NumDimK,
NumDimO, NumDimO,
DataType, InputDataType,
OutputDataType,
GemmDataType, GemmDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
...@@ -249,7 +250,8 @@ using DeviceGemmInstanceBWD = ...@@ -249,7 +250,8 @@ using DeviceGemmInstanceBWD =
1, // CShuffleNXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle
S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization MaskingSpec,
Deterministic>; // MaskingSpecialization
#elif(DIM <= 64) #elif(DIM <= 64)
using DeviceGemmInstanceFWD = using DeviceGemmInstanceFWD =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
...@@ -329,7 +331,8 @@ using DeviceGemmInstanceBWD = ...@@ -329,7 +331,8 @@ using DeviceGemmInstanceBWD =
NumDimN, NumDimN,
NumDimK, NumDimK,
NumDimO, NumDimO,
DataType, InputDataType,
OutputDataType,
GemmDataType, GemmDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
...@@ -388,7 +391,8 @@ using DeviceGemmInstanceBWD = ...@@ -388,7 +391,8 @@ using DeviceGemmInstanceBWD =
1, // CShuffleNXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization MaskingSpec,
Deterministic>; // MaskingSpecialization
// using DeviceGemmInstanceBWD = // using DeviceGemmInstanceBWD =
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2< // ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
......
...@@ -146,6 +146,9 @@ __global__ void ...@@ -146,6 +146,9 @@ __global__ void
c0_matrix_mask, c0_matrix_mask,
p_drop, p_drop,
ph, ph,
g_idx,
MRaw,
NRaw,
i); i);
} }
} }
...@@ -178,6 +181,9 @@ __global__ void ...@@ -178,6 +181,9 @@ __global__ void
c0_matrix_mask, c0_matrix_mask,
p_drop, p_drop,
ph, ph,
g_idx,
MRaw,
NRaw,
0); 0);
} }
...@@ -1007,8 +1013,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -1007,8 +1013,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
arg.c0_matrix_mask_, arg.c0_matrix_mask_,
arg.p_drop_, arg.p_drop_,
arg.seed_, arg.seed_,
arg.offset_); arg.offset_,
arg.raw_lengths_mz_nz_kz_gemm1nz_[0],
arg.raw_lengths_mz_nz_kz_gemm1nz_[1]);
}; };
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment