Commit fdfb2f61 authored by root
Browse files

Fix data type issue; now hitting a fast_numeric_converter call issue

parent c2a77a07
...@@ -80,7 +80,8 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_BScale_X ...@@ -80,7 +80,8 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_BScale_X
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>; // ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3>;
// clang-format on // clang-format on
int main(int argc, char* argv[]) int main(int argc, char* argv[])
......
...@@ -798,24 +798,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1_b_scale ...@@ -798,24 +798,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1_b_scale
private: private:
static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){}; static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
static constexpr auto src_oob_thread_scratch_desc_ =
decltype(GetSrcThreadScratchDescriptor()){};
static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){}; static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};
using SrcThreadScratch = using SrcThreadScratch =
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr, StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
DstData, // apply data_convert with SrcThreadScratch SrcData, // apply data_convert with SrcThreadScratch
SrcScalarPerVector, SrcScalarPerVector,
decltype(src_thread_scratch_desc_), decltype(src_thread_scratch_desc_),
true>; true>;
using SrcOOBThreadScratch =
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
bool, // apply data_convert with SrcThreadScratch
1,
decltype(src_oob_thread_scratch_desc_),
true>;
// Registers, contain fast converted data // Registers, contain fast converted data
using SrcThreadConvertedScratch = using SrcThreadConvertedScratch =
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr, StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
...@@ -834,7 +825,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1_b_scale ...@@ -834,7 +825,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1_b_scale
FastNumericArrayConverter<SrcData, DstData, SrcScalarPerVector>; FastNumericArrayConverter<SrcData, DstData, SrcScalarPerVector>;
StaticallyIndexedArray<SrcThreadScratch, NumThreadScratch> src_thread_scratch_tuple_; StaticallyIndexedArray<SrcThreadScratch, NumThreadScratch> src_thread_scratch_tuple_;
StaticallyIndexedArray<SrcOOBThreadScratch, NumThreadScratch> src_oob_thread_scratch_tuple_;
SrcThreadConvertedScratch src_converted_thread_scratch_; SrcThreadConvertedScratch src_converted_thread_scratch_;
DstThreadScratch dst_thread_scratch_; DstThreadScratch dst_thread_scratch_;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment