Commit 9739ede0 authored by aska-0096's avatar aska-0096
Browse files

temp save

parent 0a808724
...@@ -35,7 +35,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle ...@@ -35,7 +35,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| // ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| // ######| | | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 4, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 32, 1, 8>, 8>; < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 4, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
using ReferenceGemmInstance = ck::tensor_operation::host:: using ReferenceGemmInstance = ck::tensor_operation::host::
......
...@@ -481,7 +481,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma ...@@ -481,7 +481,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
/* typename SrcDesc, */ decltype(a_grid_desc_k0_m_k1), /* typename SrcDesc, */ decltype(a_grid_desc_k0_m_k1),
/* typename DstDesc, */ decltype(a_block_desc_k0perblock_mperblock_k1), /* typename DstDesc, */ decltype(a_block_desc_k0perblock_mperblock_k1),
/* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder, /* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder,
/* typename DstDimAccessOrder, */ Sequence<1, 0, 2>, /* typename DstDimAccessOrder, */ Sequence<0, 1, 2>,
/* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim, /* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim,
/* index_t DstVectorDim, */ 2, /* index_t DstVectorDim, */ 2,
/* index_t SrcScalarPerVector, */ ABlockTransferSrcScalarPerVector, /* index_t SrcScalarPerVector, */ ABlockTransferSrcScalarPerVector,
...@@ -513,7 +513,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma ...@@ -513,7 +513,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
decltype(b_grid_desc_k0_n_k1), decltype(b_grid_desc_k0_n_k1),
decltype(b_block_desc_k0perblock_nperblock_k1), decltype(b_block_desc_k0perblock_nperblock_k1),
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>, Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim, BBlockTransferSrcVectorDim,
2, 2,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
......
...@@ -355,5 +355,17 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a, ...@@ -355,5 +355,17 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
c3); c3);
} }
// Ranged input operand
__device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a,
half16_t b,
float8_t& c)
{
asm volatile("\n \
v_wmma_f32_16x16x16_f16_w32 %0, %1, %2, %0\n \
"
: "=v"(c)
: "v"(a), "v"(b), "0"(c));
}
} // namespace ck } // namespace ck
#endif #endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment