"docs/git@developer.sourcefind.cn:OpenDAS/torch-scatter.git" did not exist on "d63eb9c9d65711617c6d2104c0af61ea47392f03"
Commit 8efd363f authored by aska-0096's avatar aska-0096
Browse files

compiler issue fixed + increase tile size

parent 40ec8e5d
...@@ -27,7 +27,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle ...@@ -27,7 +27,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| // ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| // ######| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>; < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 256, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, 1>;
// clang-format on // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host:: using ReferenceGemmInstance = ck::tensor_operation::host::
......
...@@ -143,6 +143,29 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle ...@@ -143,6 +143,29 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
"wrong!"); "wrong!");
} }
// Thread level, register decriptor. Vector-write
__host__ __device__ static constexpr auto
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
{
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
constexpr auto MSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0];
constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1];
constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
return make_naive_tensor_descriptor_packed(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
make_tuple(Number<MRepeat>{},
I1,
MSubGroup,
Number<NRepeat>{},
I1,
NThreadPerSubGroup,
MAccVgprs));
}
// Provide dimension size // Provide dimension size
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
...@@ -550,12 +573,12 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO ...@@ -550,12 +573,12 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
constexpr index_t c_offset = constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0)); c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0));
s_nop(); // s_nop();
wmma_gemm.template Run( wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}), a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}), b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})); c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
s_nop(); // s_nop();
}); });
if constexpr(KPerBlock > WmmaK) if constexpr(KPerBlock > WmmaK)
{ {
...@@ -590,12 +613,12 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO ...@@ -590,12 +613,12 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
constexpr index_t c_offset = c_thread_desc_.CalculateOffset( constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(WmmaInnerloop + RepeatDiff, iN, 0)); make_tuple(WmmaInnerloop + RepeatDiff, iN, 0));
s_nop(); // s_nop();
wmma_gemm.template Run( wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}), a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}), b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})); c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
s_nop(); // s_nop();
}); });
// Read Consumed Next inner loop A // Read Consumed Next inner loop A
...@@ -626,12 +649,12 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO ...@@ -626,12 +649,12 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
constexpr index_t c_offset = constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(iM, WmmaInnerloop, 0)); c_thread_desc_.CalculateOffset(make_tuple(iM, WmmaInnerloop, 0));
s_nop(); // s_nop();
wmma_gemm.template Run( wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}), a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}), b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})); c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
s_nop(); // s_nop();
}); });
// Read Consumed Next inner loop B // Read Consumed Next inner loop B
b_thread_copy_.Run( b_thread_copy_.Run(
...@@ -662,12 +685,12 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO ...@@ -662,12 +685,12 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
constexpr index_t c_offset = constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0)); c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0));
s_nop(); // s_nop();
wmma_gemm.template Run( wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}), a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}), b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})); c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
s_nop(); // s_nop();
}); });
if constexpr(KPerBlock > WmmaK) if constexpr(KPerBlock > WmmaK)
{ {
...@@ -702,12 +725,12 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO ...@@ -702,12 +725,12 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
constexpr index_t c_offset = constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(WmmaInnerloop + RepeatDiff, iN, 0)); c_thread_desc_.CalculateOffset(make_tuple(WmmaInnerloop + RepeatDiff, iN, 0));
s_nop(); // s_nop();
wmma_gemm.template Run( wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}), a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}), b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})); c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
s_nop(); // s_nop();
}); });
// Col Repeatation // Col Repeatation
...@@ -728,12 +751,12 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO ...@@ -728,12 +751,12 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
constexpr index_t c_offset = constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(iM, WmmaInnerloop, 0)); c_thread_desc_.CalculateOffset(make_tuple(iM, WmmaInnerloop, 0));
s_nop(); // s_nop();
wmma_gemm.template Run( wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}), a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}), b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})); c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
s_nop(); // s_nop();
}); });
}); });
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment