gaoqiong / composable_kernel · Commits · bf75259f

Commit bf75259f, authored Aug 16, 2023 by aska-0096
New implementation of fp16Aint8B GEMM; achieves math throughput similar to that of native fp16 GEMM.
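For orientation, here is a minimal host-side sketch of what an fp16A x int8B weight-only dequant GEMM computes; the kernels touched by this commit implement a fused, blockwise version of this on the GPU. The zero-point convention, function name, and layouts below are illustrative assumptions, not taken from the CK sources.

    #include <cstdint>
    #include <vector>

    // Reference semantics (assumed, not taken from the CK sources):
    //   C[m][n] = sum_k A[m][k] * dequant(B[k][n]) * scale[n]
    // with A in fp16 (modelled as float here), B stored as uint8 with an assumed
    // zero point of 128 (consistent with the 0x6480 magic constant used by the
    // converter changed in this commit), and a per-column scale. The GPU kernels
    // fuse this dequantization into the B-operand load instead of materializing
    // an fp16 copy of B.
    void fp16Aint8B_gemm_ref(int M, int N, int K,
                             const std::vector<float>& A,     // M x K, row-major
                             const std::vector<uint8_t>& B,   // K x N, row-major
                             const std::vector<float>& scale, // N, per-column scale
                             std::vector<float>& C)           // M x N, row-major
    {
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
            {
                float acc = 0.f;
                for(int k = 0; k < K; ++k)
                    acc += A[m * K + k] * (static_cast<float>(B[k * N + n]) - 128.f) * scale[n];
                C[m * N + n] = acc;
            }
    }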
Parent: 061009a3

Showing 8 changed files with 116 additions and 939 deletions (+116 -939):

  include/ck/tensor_operation/gpu/block/blockwise_fpAintB_gemm_wmma.hpp                        +0   -624
  include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_dequant.hpp    +34  -33
  include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp                     +0   -1
  include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp                     +2   -9
  include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp                          +3   -3
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp                     +0   -5
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp                           +0   -219
  include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp     +77  -45
include/ck/tensor_operation/gpu/block/blockwise_fpAintB_gemm_wmma.hpp  (deleted, file mode 100644 → 0; view file @ 061009a3; 624 lines removed, diff collapsed)
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_dequant.hpp  (view file @ bf75259f)

@@ -57,7 +57,8 @@ struct ThreadGroupTensorSliceTransfer_v4r1_dequant
    static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();

    static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
    static constexpr auto scale_thread_slice_lengths =
        BlockScaleSliceLengths{} / ThreadClusterLengths{};

    using Index = MultiIndex<nDim>;

@@ -92,7 +93,8 @@ struct ThreadGroupTensorSliceTransfer_v4r1_dequant
        static_assert(
            is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{} &&
                is_same<BlockScaleSliceLengths,
                        decltype(scale_thread_slice_lengths * ThreadClusterLengths{})>{},
            "wrong! threads should be mapped to cover entire slicing window");

        static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),

@@ -108,8 +110,8 @@ struct ThreadGroupTensorSliceTransfer_v4r1_dequant
            threadwise_transfer_.SetSrcSliceOrigin(src_desc,
                                                   src_block_slice_origin + thread_data_idx_begin);
            threadwise_transfer_.SetScaleSliceOrigin(
                scale_desc, scale_block_slice_origin + thread_data_idx_begin);
            threadwise_transfer_.SetDstSliceOrigin(dst_desc,
                                                   dst_block_slice_origin + thread_data_idx_begin);
        }

@@ -129,8 +131,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1_dequant
    // With the assumption, scale scratch is always one
    template <typename ScaleBuffer>
    __device__ void RunScaleRead(const ScaleDesc& scale_desc, const ScaleBuffer& scale_buf)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp  (view file @ bf75259f)

@@ -677,7 +677,6 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
        std::map<PipelineVersion, std::string> PipelineVersionToString{
            {PipelineVersion::v1, "v1"},
            {PipelineVersion::v2, "v2"},
-           {PipelineVersion::dequant_v1, "dequant_v1"},
            {PipelineVersion::weight_only, "weight_only"}};
        // clang-format off
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp  (view file @ bf75259f)

@@ -404,18 +404,11 @@ struct FastNumericArrayConverter<uint8_t, ck::half_t, 4>
        half_2[0] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_01);
        half_2[1] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_23);

-       // static constexpr ck::half_t fp16_subtract = -1152;
-       // Output.template AsType<ck::half_t>()(Number<0>{}) += fp16_subtract;
-       // Output.template AsType<ck::half_t>()(Number<1>{}) += fp16_subtract;
-       // Output.template AsType<ck::half_t>()(Number<2>{}) += fp16_subtract;
-       // Output.template AsType<ck::half_t>()(Number<3>{}) += fp16_subtract;
-       // inline assembly get very poor performance as no chance to global scheduling

        static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;

        asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]"
                     : "=v"(half_2[0])
                     : "v"(half_2[0]), "s"(I8s_TO_F16s_MAGIC_NUM));
        asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]"
                     : "=v"(half_2[1])
                     : "v"(half_2[1]), "s"(I8s_TO_F16s_MAGIC_NUM));
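The sequence above leans on a standard byte-to-half bit trick: OR-ing a byte b into the mantissa of the fp16 pattern 0x6400 produces exactly 1024 + b, and the packed v_pk_add_f16 with neg_lo/neg_hi on the second operand then subtracts the magic constant (each half of 0x64806480 decodes to 1152.0), leaving b - 128 per lane. A small host-side check of that arithmetic; the decoder helper and the loop are illustrative, not CK code.

    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    // Decode an IEEE fp16 bit pattern to float (normal numbers only, which is
    // all this demonstration produces).
    static float half_bits_to_float(uint16_t h)
    {
        const int exponent = (h >> 10) & 0x1F;
        const int mantissa = h & 0x3FF;
        const float sign   = (h & 0x8000) ? -1.f : 1.f;
        return sign * std::ldexp(1.f + mantissa / 1024.f, exponent - 15);
    }

    int main()
    {
        for(int b = 0; b < 256; ++b)
        {
            const uint16_t bits = 0x6400 | static_cast<uint16_t>(b); // fp16 value 1024.0 + b
            const float dequant = half_bits_to_float(bits) - 1152.f; // 0x6480 decodes to 1152.0
            std::printf("byte %3d -> %6.1f\n", b, dequant);          // prints b - 128
        }
        return 0;
    }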
include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp  (view file @ bf75259f; +3 -3, diff not expanded)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp  (view file @ bf75259f)

@@ -12,7 +12,6 @@ enum struct PipelineVersion
{
    v1,
    v2,
-   dequant_v1,
    weight_only,
};

@@ -38,10 +37,6 @@ constexpr auto GridwiseGemmPipeline_Selector()
    {
        return GridwiseGemmPipeline_v2{};
    }
-   else if constexpr(PipelineVer == PipelineVersion::dequant_v1)
-   {
-       return GridwiseGemmPipeline_v1_dequant<NumPrefetch, AEnableLds, BEnableLds>{};
-   }
    else if constexpr(PipelineVer == PipelineVersion::weight_only)
    {
        return GridwiseGemmPipeline_v1_WeightOnly<NumPrefetch, AEnableLds, BEnableLds>{};
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp  (view file @ bf75259f)

@@ -550,225 +550,6 @@ struct GridwiseGemmPipeline_v1<1, false, false>
    }
};

Removed by this commit (219 lines): the GridwiseGemmPipeline_v1_dequant primary template declaration and its <1, true, true> and <1, true, false> specializations, reproduced below.

template <index_t NumPrefetch, bool AEnableLds, bool BEnableLds>
struct GridwiseGemmPipeline_v1_dequant;

template <>
struct GridwiseGemmPipeline_v1_dequant<1, true, true>
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }

    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
    {
        return num_loop > 1;
    }

    template <bool HasMainLoop,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename ScaleGridDesc,
              typename ScaleBlockDesc,
              typename ScaleBlockTransfer,
              typename ScaleGridBuffer,
              typename ScaleBlockBuffer,
              typename BlockwiseGemm,
              typename CThreadBuffer>
    __device__ static void Run(const AGridDesc& a_grid_desc,
                               const ABlockDesc& a_block_desc,
                               ABlockTransfer& a_blockwise_copy,
                               const AGridBuffer& a_grid_buf,
                               ABlockBuffer& a_block_buf,
                               const ABlockTransferStep& a_block_copy_step,
                               const BGridDesc& b_grid_desc,
                               const BBlockDesc& b_block_desc,
                               BBlockTransfer& b_blockwise_copy,
                               const BGridBuffer& b_grid_buf,
                               BBlockBuffer& b_block_buf,
                               const BBlockTransferStep& b_block_copy_step,
                               const ScaleGridDesc& scale_grid_desc,
                               const ScaleBlockDesc& scale_block_desc,
                               ScaleBlockTransfer& scale_blockwise_copy,
                               const ScaleGridBuffer& scale_grid_buf,
                               ScaleBlockBuffer& scale_block_buf,
                               const BlockwiseGemm& blockwise_gemm,
                               CThreadBuffer& c_thread_buf,
                               index_t num_loop)
    {
        // preload data into LDS
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
        scale_blockwise_copy.RunRead(scale_grid_desc, scale_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Initialize C
        c_thread_buf.Clear();

        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
        scale_blockwise_copy.RunWrite(scale_block_desc, scale_block_buf);

        // main body
        if constexpr(HasMainLoop)
        {
            index_t i = 0;

            do
            {
                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);

                block_sync_lds();

                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                blockwise_gemm.Run(a_block_buf, b_block_buf, scale_block_buf, c_thread_buf);

                block_sync_lds();

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                ++i;
            } while(i < (num_loop - 1));
        }

        // tail
        {
            block_sync_lds();

            blockwise_gemm.Run(a_block_buf, b_block_buf, scale_block_buf, c_thread_buf);
        }
    }
};

template <>
struct GridwiseGemmPipeline_v1_dequant<1, true, false>
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }

    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
    {
        return num_loop > 1;
    }

    template <bool HasMainLoop,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename ScaleGridDesc,
              typename ScaleBlockDesc,
              typename ScaleBlockTransfer,
              typename ScaleGridBuffer,
              typename ScaleBlockBuffer,
              typename BlockwiseGemm,
              typename CThreadBuffer>
    __device__ static void Run(const AGridDesc& a_grid_desc,
                               const ABlockDesc& a_block_desc,
                               ABlockTransfer& a_blockwise_copy,
                               const AGridBuffer& a_grid_buf,
                               ABlockBuffer& a_block_buf,
                               const ABlockTransferStep& a_block_copy_step,
                               const BGridDesc& b_grid_desc,
                               const BBlockDesc& b_block_desc,
                               BBlockTransfer& b_blockwise_copy,
                               const BGridBuffer& b_grid_buf,
                               BBlockBuffer& b_block_buf,
                               const BBlockTransferStep& b_block_copy_step,
                               const ScaleGridDesc& scale_grid_desc,
                               const ScaleBlockDesc& scale_block_desc,
                               ScaleBlockTransfer& scale_blockwise_copy,
                               const ScaleGridBuffer& scale_grid_buf,
                               ScaleBlockBuffer& scale_block_buf,
                               const BlockwiseGemm& blockwise_gemm,
                               CThreadBuffer& c_thread_buf,
                               index_t num_loop)
    {
        constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
        auto b_block_buf_switch           = b_block_buf;

        // preload data into LDS
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf);
        scale_blockwise_copy.Run(
            scale_grid_desc, scale_grid_buf, scale_block_desc, b_block_origin_idx, scale_block_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Initialize C
        c_thread_buf.Clear();

        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);

        // main body
        if constexpr(HasMainLoop)
        {
            index_t i = 0;

            do
            {
                b_blockwise_copy.Run(
                    b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf_switch);

                block_sync_lds();

                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);

                blockwise_gemm.Run(a_block_buf, b_block_buf, scale_block_buf, c_thread_buf);

                block_sync_lds();

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);

                b_block_buf = b_block_buf_switch;

                ++i;
            } while(i < (num_loop - 1));
        }

        // tail
        {
            block_sync_lds();

            blockwise_gemm.Run(a_block_buf, b_block_buf, scale_block_buf, c_thread_buf);

            block_sync_lds();
        }
    }
};

End of removed code. The declaration that follows it is unchanged:

template <index_t NumPrefetch, bool AEnableLds, bool BEnableLds>
struct GridwiseGemmPipeline_v1_WeightOnly;
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp  (view file @ bf75259f)

@@ -114,7 +114,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
        src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
    }

    __device__ void SetScaleSliceOrigin(const ScaleDesc& scale_desc,
                                        const Index& scale_slice_origin_idx)
    {
        scale_coord_ = make_tensor_coordinate(scale_desc, scale_slice_origin_idx);
    }

@@ -274,8 +275,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
    }

    template <typename ScaleBuffer>
    __device__ void RunScaleRead(const ScaleDesc& scale_desc, const ScaleBuffer& scale_buf)
    {
        static_assert(ScaleBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
                          ScaleBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,

@@ -358,11 +358,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
                   scale_scalar_per_access;
            }();

            constexpr auto scale_data_idx_seq = generate_sequence_v2(
                [&](auto i) { return Number<scale_data_idx[i]>{}; },
                Number<scale_data_idx.Size()>{});

            const bool is_scale_valid =
                coordinate_has_valid_offset_assuming_visible_index_is_valid(scale_desc, scale_coord_);

            using scale_vector_type = vector_type_maker_t<ScaleData, ScaleScalarPerVector>;
            using scale_vector_t    = typename scale_vector_type::type;

@@ -372,8 +373,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
                scale_buf.template Get<scale_vector_t>(scale_coord_.GetOffset(), is_scale_valid)};

            // copy data from scale_vector_container into scale_thread_scratch_
            scale_thread_scratch_.template SetAsType<scale_vector_t>(
                scale_data_idx_seq,
                scale_vector_container.template AsType<scale_vector_t>()[I0]);

            constexpr auto move_on_dim = [&]() constexpr

@@ -381,7 +381,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
                StaticallyIndexedArray<bool, nDim> move_on_dim_;

                static_for<0, nDim, 1>{}([&](auto i) {
                    move_on_dim_(i) =
                        ordered_scale_access_idx[i] < ordered_scale_access_lengths[i] - 1;

                    static_for<i + 1, nDim, 1>{}([&](auto j) {
                        move_on_dim_(i) &=

@@ -399,13 +400,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
                {
                    if constexpr(forward_sweep[i])
                    {
                        move_tensor_coordinate(
                            scale_desc, scale_coord_, scale_forward_steps[scale_dim_access_order[i]]);
                    }
                    else
                    {
                        move_tensor_coordinate(
                            scale_desc, scale_coord_, scale_backward_steps[scale_dim_access_order[i]]);
                    }
                }
            });

@@ -500,20 +503,46 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
                // do data transpose
                transpose_vectors<SrcData, DstScalarPerVector, SrcScalarPerVector>{}(
                    src_vector_refs, dst_vector_refs);
-               // do fast numeric convert
-               src_converted_thread_scratch_.template SetAsType<SrcThreadConvertedScratch::V>(
-                   access_idx,
-                   fast_numeric_converter(src_thread_scratch_tuple_[thread_scratch_id]
-                                              .template GetAsType<SrcThreadScratch::V>(access_idx)));
            });
        }

        // Do fast numeric convert
        constexpr auto scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access_for_src_and_dst_idle<SrcVectorDim,
                                                                  SrcScalarPerVector,
                                                                  DstVectorDim,
                                                                  DstScalarPerVector>{},
            Number<nDim>{});

        constexpr auto access_lengths = SliceLengths{} / scalar_per_access;

        using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
        using src_vector_t    = typename src_vector_type::type;

        using src_converted_vector_type = vector_type_maker_t<DstData, SrcScalarPerVector>;
        using src_converted_vector_t    = typename src_converted_vector_type::type;

        // Vector-wise type convert
        static_ford<decltype(access_lengths)>{}([&](auto access_idx) {
            auto src_vector_container =
                src_vector_type{src_thread_scratch_tuple_[thread_scratch_id]
                                    .template GetAsType<src_vector_t>(access_idx)};

            auto src_converted_vector_container =
                src_converted_vector_type{fast_numeric_converter(src_vector_container)};

            src_converted_thread_scratch_.template SetAsType<src_converted_vector_t>(
                access_idx,
                src_converted_vector_container.template AsType<src_converted_vector_t>()[I0]);
        });

        // Element-scale operation, expect packed multiplication
        static_ford<SliceLengths>{}([&](auto idx) {
            // apply the src elementwise op and convert to DstData under the hood if needed
            // Scale is dynamic, could not implement through element_op.
            DstData dst_v;
            constexpr auto scale_idx = Sequence<I0, idx.At(1), I0>{};

            src_element_op_(dst_v,
                            src_converted_thread_scratch_[idx] * scale_thread_scratch_[scale_idx]);
            // printf("Tid: %03d, scale: %04x\n", get_thread_local_1d_id(),
            // *(reinterpret_cast<const uint16_t*>(&scale_thread_scratch_[scale_idx])));

            dst_thread_scratch_(idx) = dst_v;
        });
#endif

@@ -978,13 +1007,14 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
    private:
    static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
    static constexpr auto scale_thread_scratch_desc_ =
        decltype(GetScaleThreadScratchDescriptor()){};
    static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};

    /*
    template <bool kLastDim>
    struct ScaleThreadScratchDesc{};
    */

    // Registers, contain raw data loaded from global buffer
    using SrcThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,

@@ -994,7 +1024,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
                                                             true>;

    // Registers, contain fast converted data
    using SrcThreadConvertedScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
                                                                      DstData,
                                                                      SrcScalarPerVector,
                                                                      decltype(src_thread_scratch_desc_),

@@ -1014,7 +1045,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
                                                       decltype(dst_thread_scratch_desc_),
                                                       true>;

    using FastTypeConverter = tensor_operation::element_wise::
        FastNumericArrayConverter<SrcData, DstData, SrcScalarPerVector>;

    StaticallyIndexedArray<SrcThreadScratch, NumThreadScratch> src_thread_scratch_tuple_;
    SrcThreadConvertedScratch src_converted_thread_scratch_;
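The rewritten transfer above splits dequantization into two passes over the register scratch: a vector-wise fast numeric convert of the raw int8 weights, followed by a separate element-wise multiply by the dynamic per-column scale (which cannot be folded into a static element_op because the scale is a runtime tensor). A scalar sketch of those two stages; the zero point, function name, and per-element scale layout are illustrative assumptions, not CK code.

    #include <array>
    #include <cstdint>

    // Stage 1: numeric convert (stands in for FastNumericArrayConverter).
    // Stage 2: apply the scale; in the kernel the scale is indexed per output
    // column, here it is given per element for simplicity.
    template <int N>
    std::array<float, N> dequant_tile(const std::array<uint8_t, N>& raw,
                                      const std::array<float, N>& scale)
    {
        std::array<float, N> converted{};
        for(int i = 0; i < N; ++i)
            converted[i] = static_cast<float>(raw[i]) - 128.f; // assumed zero point of 128

        std::array<float, N> out{};
        for(int i = 0; i < N; ++i)
            out[i] = converted[i] * scale[i];

        return out;
    }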