"examples/community/stable_diffusion_controlnet_img2img.py" did not exist on "44e56de9aaaa103ad11ca2953dc86ba6f64ba5d4"
Commit 2a0592a9 authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed gfx12

parent 7cb8a89f
...@@ -19,49 +19,48 @@ using AElementOp = PassThrough; ...@@ -19,49 +19,48 @@ using AElementOp = PassThrough;
using BElementOp = PassThrough; using BElementOp = PassThrough;
using CElementOp = PassThrough; using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle< using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle<ALayout,
ALayout, BLayout,
BLayout, CLayout,
CLayout, ADataType,
ADataType, BDataType,
BDataType, CDataType,
CDataType, AccDataType,
AccDataType, CShuffleDataType,
CShuffleDataType, AElementOp,
AElementOp, BElementOp,
BElementOp, CElementOp,
CElementOp, GemmDefault,
GemmDefault, 1,
2, // Prefetch stage 32,
256, // BlockSize 16,
128, // MPerBlock 32,
256, // NPerBlock 64,
64, // KPerBlock 8,
8, // K1 16,
16, // MPerWmma 16,
16, // NPerWmma 1,
4, // M-Repeat // M-PerWmma / M-Repeat = M-Wave 2,
4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave S<2, 16, 1>,
S<4, 64, 1>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, 2,
2, 8,
8, 8,
8, true,
true, S<2, 16, 1>,
S<4, 64, 1>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, 2,
2, 8,
8, 8,
8, true,
true, 1,
1, // C shuffle (M Repeat) Per store 1,
1, // C shuffle (N Repeat) Per store S<1, 16, 1, 2>,
S<1, 32, 1, 8>, 8>;
8>;
using ReferenceGemmInstance = ck::tensor_operation::host:: using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>; ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
......
...@@ -73,6 +73,7 @@ struct BlockwiseGemmWMMA ...@@ -73,6 +73,7 @@ struct BlockwiseGemmWMMA
static constexpr index_t A_KRow = 1; static constexpr index_t A_KRow = 1;
static constexpr index_t B_KRow = 1; static constexpr index_t B_KRow = 1;
#endif #endif
static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5); static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5); static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5);
......
...@@ -136,6 +136,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_gfx12, ...@@ -136,6 +136,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_gfx12,
// static constexpr index_t src_b_data_size = 2; // static constexpr index_t src_b_data_size = 2;
// static constexpr index_t acc_data_size = 4; // static constexpr index_t acc_data_size = 4;
// * Thread mapping inside wave, num_thread_per_subgroups always alone N direction // * Thread mapping inside wave, num_thread_per_subgroups always alone N direction
static constexpr index_t acc_data_size = 4;
static constexpr index_t acc_pack_number = 1; static constexpr index_t acc_pack_number = 1;
static constexpr index_t num_thread_per_subgroups = n_per_wmma; static constexpr index_t num_thread_per_subgroups = n_per_wmma;
...@@ -565,14 +566,20 @@ struct WmmaGemm ...@@ -565,14 +566,20 @@ struct WmmaGemm
__host__ __device__ static auto CalculateAThreadOriginDataIndex() __host__ __device__ static auto CalculateAThreadOriginDataIndex()
{ {
// return GetLaneIdUnderSubGroup(); #ifdef __gfx12__
return GetLaneIdUnderSubGroup();
#else
return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow(); return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow();
#endif
} }
__host__ __device__ static auto CalculateBThreadOriginDataIndex() __host__ __device__ static auto CalculateBThreadOriginDataIndex()
{ {
// return GetLaneIdUnderSubGroup(); #ifdef __gfx12__
return GetLaneIdUnderSubGroup();
#else
return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup(); return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup();
#endif
} }
__device__ static CIndex GetBeginOfThreadBlk() __device__ static CIndex GetBeginOfThreadBlk()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment