gaoqiong / composable_kernel_ROCM / Commits / e35f2c1a

Commit e35f2c1a, authored Dec 05, 2024 by aska-0096
Parent: 27f9ed07

optimize perf; enable v4; i4_bufferload_not_solved

Showing 6 changed files with 597 additions and 242 deletions (+597, -242)
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp    +41   -59
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp    +164  -179
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp    +3    -3
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp               +155  -0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp         +3    -1
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp            +231  -0
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp

@@ -359,13 +359,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
         // Initialize C
         c_thread_buf.Clear();
 
-        StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
-                                  AccDataType,
-                                  1,
-                                  xdlops_gemm.GetRegSizePerXdlops(),
-                                  true>
-            c_thread_buf_per_scale;
-
         // Local prefetch 1
         block_sync_lds();
         static_for<0, KRepeat, 1>{}([&](auto k0) {
@@ -381,6 +374,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
                 b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                    make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
                                    b_block_buf,
+                                   b_scale_thread_buf[n0],
                                    b_thread_desc_,
                                    make_tuple(n0, I0, k0, I0),
                                    b_thread_buf);
@@ -406,10 +400,31 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
             a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
             b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
-            static_for<0, MRepeat, 1>{}([&](auto m0) {
-                static_for<0, NRepeat, 1>{}([&](auto n0) {
-                    c_thread_buf_per_scale.Clear();
-                    static_for<0, KRepeat, 1>{}([&](auto k0) {
+            static_for<0, NRepeat, 1>{}([&](auto n0) {
+                b_scale_thread_copy.Run(b_scale_grid_desc,
+                                        b_scale_grid_buf,
+                                        b_scale_thread_desc,
+                                        make_tuple(n0, I0),
+                                        b_scale_thread_buf);
+
+                b_scale_thread_copy.MoveSrcSliceWindow(
+                    b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{}));
+            });
+
+            if((i + 2) % num_loop_per_scale == 0)
+            {
+                b_scale_thread_copy.MoveSrcSliceWindow(
+                    b_scale_grid_desc, b_scale_thread_copy_step.At(Number<2>{}));
+            }
+            else
+            {
+                b_scale_thread_copy.MoveSrcSliceWindow(
+                    b_scale_grid_desc, b_scale_thread_copy_step.At(Number<1>{}));
+            }
+
+            static_for<0, KRepeat, 1>{}([&](auto k0) {
+                static_for<0, MRepeat, 1>{}([&](auto m0) {
+                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                         vector_type<ComputeDataType, KPack> a_thread_vec;
                         vector_type<ComputeDataType, KPack> b_thread_vec;
@@ -426,20 +441,13 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
                             typename vector_type<ComputeDataType,
                                                  xdlops_gemm.K1PerXdlops>::type;
 
-                        // constexpr index_t c_offset =
-                        //     c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
-
-                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
-                                        b_thread_vec.template AsType<mfma_input_type>(),
-                                        c_thread_buf_per_scale.GetVectorTypeReference(I0));
-                    });
-                    static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
-                        constexpr index_t c_offset =
-                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
-                        c_thread_buf(Number<c_offset>{}) +=
-                            c_thread_buf_per_scale[Number<t>{}] *
-                            type_convert<AccDataType>(b_scale_thread_buf[n0]);
+                        constexpr index_t c_offset =
+                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
+
+                        xdlops_gemm.Run(
+                            // type_convert<AccDataType>(a_scale_thread_buf[I0]) *
+                            a_thread_vec.template AsType<mfma_input_type>(),
+                            b_thread_vec.template AsType<mfma_input_type>(),
+                            c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                     });
                 });
             });
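The hunk above drops the per-accumulator rescaling loop because the B scale is now applied while B is dequantized during the copy. Below is a standalone sketch, not CK code, of why the two formulations agree, assuming the scale is constant over the K elements being reduced (which is what folding it into the copy relies on); names and values are illustrative only.

```cpp
// Standalone sketch (not part of CK): scalar model of the two variants in the hunk above.
#include <cstdio>

int main()
{
    const float a[4]  = {1.0f, -2.0f, 0.5f, 3.0f};
    const float b[4]  = {4.0f, 1.0f, -8.0f, 2.0f}; // stands in for dequantized int4 values
    const float scale = 0.125f;                    // stands in for b_scale_thread_buf[n0]

    float acc_then_scale = 0.0f; // old path: accumulate unscaled, multiply afterwards
    float scaled_operand = 0.0f; // new path: scale the B operand during the copy
    for(int k = 0; k < 4; ++k)
    {
        acc_then_scale += a[k] * b[k];
        scaled_operand += a[k] * (scale * b[k]);
    }
    acc_then_scale *= scale;

    // Mathematically identical; in floating point they can differ by rounding,
    // which is why numerics still need re-validation after a change like this.
    std::printf("%f vs %f\n", acc_then_scale, scaled_operand);
    return 0;
}
```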
@@ -459,32 +467,12 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
                     b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                        make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
                                        b_block_buf,
+                                       b_scale_thread_buf[n0],
                                        b_thread_desc_,
                                        make_tuple(n0, I0, k0, I0),
                                        b_thread_buf);
                 });
             });
 
-            static_for<0, NRepeat, 1>{}([&](auto n0) {
-                b_scale_thread_copy.Run(b_scale_grid_desc,
-                                        b_scale_grid_buf,
-                                        b_scale_thread_desc,
-                                        make_tuple(n0, I0),
-                                        b_scale_thread_buf);
-
-                b_scale_thread_copy.MoveSrcSliceWindow(
-                    b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{}));
-            });
-
-            if((i + 2) % num_loop_per_scale == 0)
-            {
-                b_scale_thread_copy.MoveSrcSliceWindow(
-                    b_scale_grid_desc, b_scale_thread_copy_step.At(Number<2>{}));
-            }
-            else
-            {
-                b_scale_thread_copy.MoveSrcSliceWindow(
-                    b_scale_grid_desc, b_scale_thread_copy_step.At(Number<1>{}));
-            }
-
             HotLoopScheduler();
             __builtin_amdgcn_sched_barrier(0);
@@ -495,10 +483,9 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
         // tail
         if constexpr(TailNum == TailNumber::Full)
         {
-            static_for<0, MRepeat, 1>{}([&](auto m0) {
-                static_for<0, NRepeat, 1>{}([&](auto n0) {
-                    c_thread_buf_per_scale.Clear();
-                    static_for<0, KRepeat, 1>{}([&](auto k0) {
+            static_for<0, KRepeat, 1>{}([&](auto k0) {
+                static_for<0, MRepeat, 1>{}([&](auto m0) {
+                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                         vector_type<ComputeDataType, KPack> a_thread_vec;
                         vector_type<ComputeDataType, KPack> b_thread_vec;
@@ -514,17 +501,12 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
                         using mfma_input_type =
                             typename vector_type<ComputeDataType,
                                                  xdlops_gemm.K1PerXdlops>::type;
 
-                        xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
-                                                 b_thread_vec.template AsType<mfma_input_type>(),
-                                                 c_thread_buf_per_scale.GetVectorTypeReference(I0));
-                    });
-                    static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
-                        constexpr index_t c_offset =
-                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
-                        c_thread_buf(Number<c_offset>{}) +=
-                            c_thread_buf_per_scale[Number<t>{}] *
-                            type_convert<AccDataType>(b_scale_thread_buf[n0]);
+                        constexpr index_t c_offset =
+                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
+
+                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
+                                        // type_convert<AccDataType>(a_scale_thread_buf[I0]) *
+                                        b_thread_vec.template AsType<mfma_input_type>(),
+                                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                     });
                 });
             });
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp

(This diff is collapsed in the page view and is not shown here.)
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp

@@ -220,9 +220,9 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
         constexpr index_t minimum_occupancy =
             BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave
                 ? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
-                   MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) > 128 * 128 * 64 * 2)
-                      ? 1
-                      : 2
+                   MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2)
+                      ? 2
+                      : 1
                 : 2;
 
         if(has_main_k_block_loop)
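For readability, here is the reworked occupancy choice written out as a standalone constexpr helper. The function name and the example tiles are mine, not part of the commit; only the cutoff value 128 * 128 * 64 * 2 and the 2-versus-1 selection mirror the diff.

```cpp
// Standalone sketch (not part of CK): the ternary above, restated as a constexpr function.
#include <cstdio>

constexpr int minimum_occupancy(bool intrawave, bool is_v3, long tile_product)
{
    // Cutoff used in the diff: 128 * 128 * 64 * 2 = 2,097,152.
    constexpr long cutoff = 128L * 128 * 64 * 2;
    return intrawave ? ((is_v3 && tile_product <= cutoff) ? 2 : 1) : 2;
}

int main()
{
    // Hypothetical tiles: MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType).
    std::printf("%d\n", minimum_occupancy(true, true, 128L * 128 * 64 * 2)); // at the cutoff -> 2
    std::printf("%d\n", minimum_occupancy(true, true, 256L * 128 * 64 * 2)); // above it      -> 1
    return 0;
}
```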
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp

@@ -11,6 +11,98 @@
 namespace ck {
 
+__host__ __device__ inline half4_t pki4_to_half4_scale(int q, const ck::half2_t& scale)
+{
+    constexpr int LO = 0x000f000f;
+    constexpr int HI = 0x00f000f0;
+    constexpr int EX = 0x64006400;
+
+    // Guarantee that the `(a & b) | c` operations are LOP3s.
+    // int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
+    // int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
+    int lo = amd_assembly_and_or_b32(q, LO, EX);
+    int hi = amd_assembly_and_or_b32(q, HI, EX);
+
+    // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
+    // directly into `SUB` and `ADD`.
+    constexpr int SUB = 0xE408E408; //-8
+    constexpr int MUL = 0x2c002c00; // 1/16
+    constexpr int ADD = 0xd480d480; //-79
+
+    vector_type<half_t, 4> res;
+
+    res.template AsType<half2_t>()(Number<0>{}) =
+        amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
+
+    res.template AsType<half2_t>()(Number<1>{}) = amd_assembly_pk_fma_f16(
+        bit_cast<half2_t>(hi), bit_cast<half2_t>(MUL), bit_cast<half2_t>(ADD));
+
+    asm volatile("v_pk_mul_f16 %0, %1, %2"
+                 : "=v"(res.template AsType<half2_t>()(Number<0>{}))
+                 : "v"(res.template AsType<half2_t>()(Number<0>{})), "v"(scale));
+
+    asm volatile("v_pk_mul_f16 %0, %1, %2"
+                 : "=v"(res.template AsType<half2_t>()(Number<1>{}))
+                 : "v"(res.template AsType<half2_t>()(Number<1>{})), "v"(scale));
+
+    return res.template AsType<half4_t>()[Number<0>{}];
+}
+
+// Further fuse the scale into inline assembly, sanity failed
+#if 0
+__host__ __device__ inline half4_t pki4_to_half4_scale(int q, const ck::half_t& scale)
+{
+    constexpr int LO = 0x000f000f;
+    constexpr int HI = 0x00f000f0;
+    constexpr int EX = 0x64006400;
+    // Guarantee that the `(a & b) | c` operations are LOP3s.
+    // int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
+    // int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
+    int lo = amd_assembly_and_or_b32(q, LO, EX);
+    int hi = amd_assembly_and_or_b32(q, HI, EX);
+
+    // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
+    // directly into `SUB` and `ADD`.
+    // constexpr int SUB = 0xE408E408; //-8
+    // constexpr int MUL = 0x2c002c00; // 1/16
+    // constexpr int ADD = 0xd480d480; //-79
+    constexpr half_t SUB = bit_cast<half_t>(static_cast<uint16_t>(0xE408));
+    constexpr half_t MUL = bit_cast<half_t>(static_cast<uint16_t>(0x2c00));
+    constexpr half_t ADD = bit_cast<half_t>(static_cast<uint16_t>(0xd480));
+
+    vector_type<half_t, 2> scale_2;
+    scale_2.template AsType<half_t>()(Number<0>{}) = scale;
+    scale_2.template AsType<half_t>()(Number<1>{}) = scale;
+
+    vector_type<half_t, 2> sub_2;
+    sub_2.template AsType<half_t>()(Number<0>{}) = SUB * scale;
+    sub_2.template AsType<half_t>()(Number<1>{}) = SUB * scale;
+
+    vector_type<half_t, 2> mul_2;
+    mul_2.template AsType<half_t>()(Number<0>{}) = MUL * scale;
+    mul_2.template AsType<half_t>()(Number<1>{}) = MUL * scale;
+
+    vector_type<half_t, 2> add_2;
+    add_2.template AsType<half_t>()(Number<0>{}) = ADD * scale;
+    add_2.template AsType<half_t>()(Number<1>{}) = ADD * scale;
+
+    vector_type<half_t, 4> res;
+
+    res.template AsType<half2_t>()(Number<0>{}) =
+        amd_assembly_pk_fma_f16(bit_cast<half2_t>(lo),
+                                scale_2.template AsType<half2_t>()(Number<0>{}),
+                                sub_2.template AsType<half2_t>()(Number<0>{}));
+
+    res.template AsType<half2_t>()(Number<1>{}) =
+        amd_assembly_pk_fma_f16(bit_cast<half2_t>(hi),
+                                mul_2.template AsType<half2_t>()(Number<0>{}),
+                                add_2.template AsType<half2_t>()(Number<0>{}));
+
+    // asm volatile("v_pk_mul_f16 %0, %1, %2"
+    //              : "=v"(res.template AsType<half2_t>()(Number<0>{}))
+    //              : "v"(res.template AsType<half2_t>()(Number<0>{})), "v"(scale));
+    // asm volatile("v_pk_mul_f16 %0, %1, %2"
+    //              : "=v"(res.template AsType<half2_t>()(Number<1>{}))
+    //              : "v"(res.template AsType<half2_t>()(Number<1>{})), "v"(scale));
+
+    return res.template AsType<half4_t>()[Number<0>{}];
+}
+#endif
+
 __host__ __device__ inline half4_t pki4_to_half4(int q)
 {
     constexpr int LO = 0x000f000f;
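A standalone host-side check, not part of the commit, of the arithmetic behind the magic constants used by pki4_to_half4_scale above: packing a nibble v into an fp16 whose exponent field is 0x6400 yields 1024 + v (low nibble) or 1024 + 16 * v (high nibble); SUB = 0xE408 is the fp16 value -1032, MUL = 0x2c00 is 1/16, and by this arithmetic ADD = 0xd480 is the fp16 value -72, i.e. -(64 + 8), so both paths reduce to the signed value v - 8 before the scale is applied.

```cpp
// Standalone sketch (not part of CK): verifies the nibble-to-fp16 constant arithmetic in float.
#include <cassert>
#include <cstdio>

int main()
{
    const float sub = -1032.0f; // fp16 0xE408
    const float mul = 0.0625f;  // fp16 0x2c00 (1/16)
    const float add = -72.0f;   // fp16 0xd480

    for(int v = 0; v < 16; ++v)
    {
        const float lo = 1024.0f + static_cast<float>(v);         // nibble in bits [3:0]
        const float hi = 1024.0f + 16.0f * static_cast<float>(v); // nibble in bits [7:4]

        assert(lo + sub == static_cast<float>(v - 8));
        assert(hi * mul + add == static_cast<float>(v - 8));
    }
    std::printf("all 16 nibble values decode to v - 8\n");
    return 0;
}
```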
@@ -119,6 +211,69 @@ struct PassThroughPack8
         result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4(bit_cast<int>(x));
         result.template AsType<half4_t>()(Number<1>{}) = pki4_to_half4(bit_cast<int>(x) >> 8);
 
         y = result.template AsType<half8_t>()[Number<0>{}];
 #else
         vector_type<half_t, 8> dst;
         vector_type<pk_i4_t, 4> src{x};
 
         dst.template AsType<half2_t>()(Number<0>{}) =
             pki4_to_half2(src.template AsType<pk_i4_t>()[Number<0>{}]);
         dst.template AsType<half2_t>()(Number<1>{}) =
             pki4_to_half2(src.template AsType<pk_i4_t>()[Number<1>{}]);
         dst.template AsType<half2_t>()(Number<2>{}) =
             pki4_to_half2(src.template AsType<pk_i4_t>()[Number<2>{}]);
         dst.template AsType<half2_t>()(Number<3>{}) =
             pki4_to_half2(src.template AsType<pk_i4_t>()[Number<3>{}]);
 
         y = dst.template AsType<half8_t>()[Number<0>{}];
 #endif
     }
     constexpr const static bool is_pack8_invocable = true;
 };
 
+struct DequantPack8
+{
+    template <typename Y, typename X, typename Z>
+    __host__ __device__ void operator()(Y& y, const X& x, const Z& z) const;
+
+    __host__ __device__ constexpr void
+    operator()(ck::half8_t& y, const ck::pk_i4x4_t& x, const ck::half2_t& z) const
+    {
+#if 1
+        int x_permute = 0;
+
+        int bits4_0 = (bit_cast<int>(x) >> 0) & 0xF;
+        int bits4_1 = (bit_cast<int>(x) >> 4) & 0xF;
+        int bits4_2 = (bit_cast<int>(x) >> 8) & 0xF;
+        int bits4_3 = (bit_cast<int>(x) >> 12) & 0xF;
+        int bits4_4 = (bit_cast<int>(x) >> 16) & 0xF;
+        int bits4_5 = (bit_cast<int>(x) >> 20) & 0xF;
+        int bits4_6 = (bit_cast<int>(x) >> 24) & 0xF;
+        int bits4_7 = (bit_cast<int>(x) >> 28) & 0xF;
+
+        x_permute |= (bits4_1 << 0);
+        x_permute |= (bits4_3 << 4);
+        x_permute |= (bits4_5 << 8);
+        x_permute |= (bits4_7 << 12);
+        x_permute |= (bits4_0 << 16);
+        x_permute |= (bits4_2 << 20);
+        x_permute |= (bits4_4 << 24);
+        x_permute |= (bits4_6 << 28);
+
+        vector_type<half_t, 8> result;
+
+        result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4_scale(x_permute, z);
+        result.template AsType<half4_t>()(Number<1>{}) = pki4_to_half4_scale(x_permute >> 8, z);
+
+        y = result.template AsType<half8_t>()[Number<0>{}];
+#elif 1
+        vector_type<half_t, 8> result;
+
+        result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4_scale(bit_cast<int>(x), z);
+        result.template AsType<half4_t>()(Number<1>{}) =
+            pki4_to_half4_scale(bit_cast<int>(x) >> 8, z);
+
+        y = result.template AsType<half8_t>()[Number<0>{}];
+#else
+        vector_type<half_t, 8> dst;
...
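The nibble shuffle in DequantPack8's first branch is easier to see as index bookkeeping. The sketch below is not CK code and makes no claim about pk_i4_t's element-packing convention; it only traces which input nibble ends up in which of the eight fp16 outputs, given that pki4_to_half4_scale consumes its argument's nibbles in the order 0, 4, 1, 5 (and, via the ">> 8" call, 2, 6, 3, 7).

```cpp
// Standalone sketch (not part of CK): traces the nibble permutation used by DequantPack8.
#include <cstdio>

int main()
{
    // Mark each nibble of the packed word with its own index.
    unsigned x = 0x76543210u;

    // The permutation from the #if 1 branch above: odd nibbles -> low half, even -> high half.
    unsigned p = 0;
    for(int i = 0; i < 4; ++i)
    {
        p |= ((x >> (4 * (2 * i + 1))) & 0xF) << (4 * i);
        p |= ((x >> (4 * (2 * i))) & 0xF) << (4 * (i + 4));
    }

    // Output order of the two pki4_to_half4_scale calls, expressed as source-nibble positions.
    const int order[8] = {0, 4, 1, 5, 2, 6, 3, 7};

    std::printf("without permute: ");
    for(int k : order)
        std::printf("%X ", (x >> (4 * k)) & 0xF);
    std::printf("\nwith permute:    ");
    for(int k : order)
        std::printf("%X ", (p >> (4 * k)) & 0xF);
    std::printf("\n"); // prints "0 4 1 5 2 6 3 7" vs "1 0 3 2 5 4 7 6"
    return 0;
}
```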
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp

@@ -1914,7 +1914,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
             make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset, 0));
 
         constexpr auto b_scale_thread_slice_copy_step =
-            make_tuple(make_multi_index(NWaves * NPerXdl, 0), make_multi_index(-NPerBlock, 1));
+            make_tuple(make_multi_index(NWaves * NPerXdl, 0),
+                       make_multi_index(-NPerBlock, 0),
+                       make_multi_index(-NPerBlock, 1));
 
         const index_t num_k_block_per_scale = ScaleBlockK / KPerBlock;
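A standalone model, not CK code, of how the three steps defined here are consumed by the pipeline hunks above: step 0 once per n0, then step 2 when (i + 2) % num_loop_per_scale == 0, otherwise step 1. The tile sizes are hypothetical and chosen only so that NPerBlock == NRepeat * NWaves * NPerXdl.

```cpp
// Standalone sketch (not part of CK): per-iteration movement of the B-scale thread window.
#include <cstdio>

int main()
{
    const int NWaves = 2, NPerXdl = 32, NRepeat = 2;
    const int NPerBlock          = NRepeat * NWaves * NPerXdl; // 128
    const int num_loop_per_scale = 4;

    int n = 0, k = 0; // window offset along (N, scale-K)
    for(int i = 0; i < 8; ++i)
    {
        for(int r = 0; r < NRepeat; ++r)
            n += NWaves * NPerXdl; // step At(0): (NWaves * NPerXdl, 0), once per n0
        if((i + 2) % num_loop_per_scale == 0)
        {
            n -= NPerBlock;
            k += 1; // step At(2): (-NPerBlock, 1) advances to the next scale-K entry
        }
        else
        {
            n -= NPerBlock; // step At(1): (-NPerBlock, 0) rewinds N, keeps the same scale-K entry
        }
        std::printf("after iteration %d: n = %d, k = %d\n", i, n, k);
    }
    return 0;
}
```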
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp

@@ -1252,6 +1252,237 @@ struct ThreadwiseTensorSliceTransfer_v4
         });
     }
 
+    // Fuse scale
+    template <typename SrcRefToOriginDisplacement,
+              typename DstOriginIdx,
+              typename SrcBuffer,
+              typename DstBuffer>
+    __device__ void Run(const SrcDesc&,
+                        const SrcRefToOriginDisplacement&,
+                        const SrcBuffer& src_buf,
+                        const DstData& scale,
+                        const DstDesc&,
+                        const DstOriginIdx&,
+                        DstBuffer& dst_buf) const
+    {
+        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
+                      "wrong! SrcDesc and DstDesc need to known at compile-time");
+
+        static_assert(
+            is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value &&
+                is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
+            "wrong! SrcBuffer or DstBuffer data type is wrong");
+
+        static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");
+
+        static_assert(
+            is_known_at_compile_time<remove_cvref_t<SrcRefToOriginDisplacement>>::value &&
+                is_known_at_compile_time<remove_cvref_t<DstOriginIdx>>::value,
+            "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
+            "at compile-time");
+
+        // SrcDesc and DstDesc are known at compile-time
+        constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
+        constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
+
+        // SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time
+        constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{});
+        constexpr auto dst_origin_idx             = to_multi_index(DstOriginIdx{});
+
+        // scalar per access of each dim
+        constexpr auto src_scalar_per_access = generate_sequence_v2(
+            [&](auto i) constexpr {
+                if constexpr(i == SrcVectorDim)
+                {
+                    return Number<SrcScalarPerVector>{};
+                }
+                else
+                {
+                    return Number<1>{};
+                }
+            },
+            Number<nDim>{});
+
+        // scalar step (if steping on SrcVectorDim) of each dim
+        constexpr auto src_scalar_step_in_vector = generate_sequence_v2(
+            [&](auto i) constexpr {
+                if constexpr(i == SrcVectorDim)
+                {
+                    return Number<1>{};
+                }
+                else
+                {
+                    return Number<0>{};
+                }
+            },
+            Number<nDim>{});
+
+        constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access;
+
+        constexpr auto dim_access_order = DimAccessOrder{};
+
+        constexpr auto ordered_access_lengths =
+            container_reorder_given_new2old(access_lengths, dim_access_order);
+
+        static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
+#if 0
+            // TODO: unable to compile
+            // position in slice window
+            constexpr auto data_to_origin_disp_idx =
+                container_reorder_given_old2new(ordered_access_idx, dim_access_order) *
+                src_scalar_per_access;
+#else
+            // position in slice window
+            constexpr auto data_to_origin_disp_idx =
+                ordered_access_idx.ReorderGivenOld2New(dim_access_order) * src_scalar_per_access;
+#endif
+            // src coordinate
+            constexpr auto src_ref_to_data_disp_idx =
+                src_ref_to_origin_disp_idx + data_to_origin_disp_idx;
+
+            constexpr auto src_ref_to_data_disp_coord_step =
+                make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx);
+
+            auto src_data_coord = src_ref_coord_;
+
+            move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step);
+
+            vector_type_maker_t<SrcData, SrcScalarPerVector / PackedSize> src_tmp_vector;
+
+            using src_vector_t = typename decltype(src_tmp_vector)::type;
+
+            const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
+                src_desc, src_data_coord);
+
+            // copy data from src_buf into src_tmp_vector
+            if constexpr(SrcBuffer::IsDynamicBuffer())
+            {
+                src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
+                    src_buf.template Get<src_vector_t>(src_data_coord.GetOffset() / PackedSize,
+                                                       is_src_valid);
+            }
+            else if constexpr(SrcBuffer::IsStaticBuffer())
+            {
+                static_assert(false, "");
+                static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
+                    constexpr index_t src_offset = src_desc.CalculateOffset(
+                        src_ref_to_origin_disp_idx + data_to_origin_disp_idx +
+                        i * src_scalar_step_in_vector);
+
+                    src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset>{}];
+                });
+            }
+
+            if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value &&
+                         is_same<remove_cvref_t<DstData>, half_t>::value)
+            {
+                // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
+                // DstData)
+                vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
+
+                vector_type<DstData, 2> scale_vector;
+                scale_vector.template AsType<DstData>()(Number<0>{}) = scale;
+                scale_vector.template AsType<DstData>()(Number<1>{}) = scale;
+
+                constexpr index_t pack_size = 8;
+
+                static_assert(SrcScalarPerVector % pack_size == 0, "");
+
+                using src_v_t = typename vector_type_maker_t<SrcData, pack_size / PackedSize>::type;
+                using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
+                using scale_v_t = typename vector_type_maker_t<DstData, 2>::type;
+
+                static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) {
+                    ck::tensor_operation::element_wise::DequantPack8{}(
+                        dst_tmp_vector.template AsType<dst_v_t>()(i),
+                        src_tmp_vector.template AsType<src_v_t>()[i],
+                        scale_vector.template AsType<scale_v_t>()[Number<0>{}]);
+                });
+
+                // copy data from dst_tmp_vector into dst_buf
+                static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
+                    constexpr index_t dst_offset = dst_desc.CalculateOffset(
+                        dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
+
+                    dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
+                });
+            }
+            else if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value &&
+                              is_same<remove_cvref_t<DstData>, f8_t>::value)
+            {
+                // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
+                // DstData)
+                vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
+
+                constexpr index_t pack_size = 8;
+
+                static_assert(SrcScalarPerVector % pack_size == 0, "");
+
+                using src_v_t = typename vector_type_maker_t<SrcData, pack_size / PackedSize>::type;
+                using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
+
+                static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) {
+                    ck::tensor_operation::element_wise::PassThroughPack8{}(
+                        dst_tmp_vector.template AsType<dst_v_t>()(i),
+                        src_tmp_vector.template AsType<src_v_t>()[i]);
+                });
+
+                // copy data from dst_tmp_vector into dst_buf
+                static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
+                    constexpr index_t dst_offset = dst_desc.CalculateOffset(
+                        dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
+
+                    dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
+                });
+            }
+            else if constexpr(is_same<remove_cvref_t<SrcData>, f8_t>::value &&
+                              is_same<remove_cvref_t<DstData>, half_t>::value &&
+                              SrcScalarPerVector % 2 == 0)
+            {
+                // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
+                // DstData)
+                vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
+
+                constexpr index_t pack_size = 2;
+
+                using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
+                using src_v_t = typename vector_type_maker_t<SrcData, pack_size>::type;
+
+                static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) {
+                    ck::tensor_operation::element_wise::PassThroughPack2{}(
+                        dst_tmp_vector.template AsType<dst_v_t>()(i),
+                        src_tmp_vector.template AsType<src_v_t>()[i]);
+                });
+
+                // copy data from dst_tmp_vector into dst_buf
+                static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
+                    constexpr index_t dst_offset = dst_desc.CalculateOffset(
+                        dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
+
+                    dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
+                });
+            }
+            else
+            {
+                // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
+                // DstData)
+                vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
+
+                // TODO: if SrcData and DstData are vetor type, then static_cast may not compile
+                static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
+                    dst_tmp_vector.template AsType<DstData>()(i) =
+                        type_convert<DstData>(src_tmp_vector.template AsType<SrcData>()[i]);
+                });
+
+                // copy data from dst_tmp_vector into dst_buf
+                static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
+                    constexpr index_t dst_offset = dst_desc.CalculateOffset(
+                        dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
+
+                    dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
+                });
+            }
+        });
+    }
+
     template <typename SrcSliceMoveStepIdx>
     __device__ void MoveSrcSliceWindow(const SrcDesc&,
                                        const SrcSliceMoveStepIdx& src_slice_move_step_idx)
...
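To summarize the data path the new overload adds, here is a scalar stand-in, not CK code, for its pk_i4_t -> half_t branch: one 8-wide dequant per pack (DequantPack8 in the real code) with the scale folded in, followed by an element-wise store. Types, packing order, and sizes below are illustrative assumptions rather than CK definitions.

```cpp
// Standalone sketch (not part of CK): scalar model of the scale-fused dequantizing copy.
#include <cstdint>
#include <cstdio>

int main()
{
    constexpr int SrcScalarPerVector = 16; // int4 elements per access (hypothetical)
    constexpr int pack_size          = 8;

    uint8_t src[SrcScalarPerVector / 2]; // two int4 values per byte (assumed low-nibble-first)
    for(int i = 0; i < SrcScalarPerVector / 2; ++i)
        src[i] = static_cast<uint8_t>(((2 * i + 1) << 4) | (2 * i)); // nibble values 0, 1, 2, ...

    const float scale = 0.25f; // stands in for the DstData scale argument
    float dst[SrcScalarPerVector];

    for(int p = 0; p < SrcScalarPerVector / pack_size; ++p) // one 8-wide dequant per pack
    {
        for(int j = 0; j < pack_size; ++j)
        {
            const int e      = p * pack_size + j;
            const int nibble = (src[e / 2] >> (4 * (e % 2))) & 0xF;
            dst[e]           = scale * static_cast<float>(nibble - 8); // -8 zero point, then scale
        }
    }

    for(int i = 0; i < SrcScalarPerVector; ++i)
        std::printf("%g ", dst[i]);
    std::printf("\n");
    return 0;
}
```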