Project: gaoqiong/composable_kernel_ROCM

Commit 786a0faa
Authored Oct 23, 2024 by Jing Zhang
Parent: 6a2521ea

    add permute switch as a template
Showing 4 changed files with 77 additions and 62 deletions (+77, -62):

  example/01_gemm/gemm_xdl_fp16_fp8_v3.cpp                                     +1   -1
  example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp                                   +25  -21
  include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp  +6   -2
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp       +45  -38
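The change below removes the file-scope #define WEIGHT_PERMUTE switch and exposes the same choice as bool PermuteA / PermuteB template parameters that the kernel branches on with if constexpr, so permuted and plain B layouts become different instantiations of one kernel. A minimal standalone sketch of that pattern, with illustrative names only (this is not the CK API):

#include <cstdio>

// Hypothetical loader: the permute behaviour is chosen per instantiation
// instead of per translation unit, as a preprocessor #define would do.
template <bool PermuteB = false>
struct TileLoader
{
    static void load()
    {
        if constexpr(PermuteB)
        {
            std::printf("load B from the permuted [K0, N, K1] layout\n");
        }
        else
        {
            std::printf("load B from the plain [N, K] layout\n");
        }
    }
};

int main()
{
    TileLoader<true>::load(); // permuted path, like PermuteB = true in the example
    TileLoader<>::load();     // default path, PermuteB = false
}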
example/01_gemm/gemm_xdl_fp16_fp8_v3.cpp

@@ -8,7 +8,7 @@
 using ADataType        = ck::half_t;
 using BDataType        = ck::f8_t;
 using AccDataType      = float;
-using CShuffleDataType = float;
+using CShuffleDataType = ck::half_t;
 using CDataType        = ck::half_t;
 using ALayout          = Row;
example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp

@@ -8,7 +8,7 @@
 using ADataType        = ck::half_t;
 using BDataType        = ck::pk_i4_t;
 using AccDataType      = float;
-using CShuffleDataType = float;
+using CShuffleDataType = ck::half_t;
 using CDataType        = ck::half_t;
 using ALayout          = Row;
@@ -21,6 +21,8 @@ using CElementOp = PassThrough;
 static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

+static constexpr bool PermuteB = true;
+
 // clang-format off
 using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3<
@@ -38,7 +40,7 @@ using DeviceGemmV2Instance =
     S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
     2, 32, 32, 1,
     1, 1, S<1, 16, 1, 4>, 4,
-    ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>;
+    ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, CDataType, CDataType, false, PermuteB>;

 [[maybe_unused]] static int KPerBlock = 256;
 #else
@@ -52,7 +54,7 @@ using DeviceGemmV2Instance =
     S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
     2, 32, 32, 0,
     1, 1, S<1, 16, 1, 8>, 4,
-    ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2>;
+    ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v2, CDataType, CDataType, false, PermuteB>;

 [[maybe_unused]] static int KPerBlock = 128;
 #endif
@@ -123,7 +125,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
         b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
         break;
     case 1:
-        a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
+        a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
         b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
         break;
     case 2:
@@ -136,7 +138,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
         break;
     default:
         a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
-        b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
+        b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
     }

     Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
@@ -150,32 +152,34 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
     DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize());
     DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());

     // weight permute
-#if 1
-    int K1 = KPerBlock;
-    int K0 = K / KPerBlock;
-
-    // int K0, N, K1
-    for(int j = 0; j < K0; j++)
-    {
-        for(int i = 0; i < N; i++)
-        {
-            for(int jj = 0; jj < K1; jj++)
-            {
-                b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj));
-            }
-        }
-    }
-#else
-    for(int i = 0; i < N; i++)
-    {
-        for(int j = 0; j < K; j++)
-        {
-            b_k_n_permute(i * K + j) = b_k_n(i * K + j);
-        }
-    }
-#endif
+    if constexpr(PermuteB)
+    {
+        int K1 = KPerBlock;
+        int K0 = K / KPerBlock;
+
+        // int K0, N, K1
+        for(int j = 0; j < K0; j++)
+        {
+            for(int i = 0; i < N; i++)
+            {
+                for(int jj = 0; jj < K1; jj++)
+                {
+                    b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj));
+                }
+            }
+        }
+    }
+    else
+    {
+        for(int i = 0; i < N; i++)
+        {
+            for(int j = 0; j < K; j++)
+            {
+                b_k_n_permute(i * K + j) = b_k_n(i * K + j);
+            }
+        }
+    }

     // vector pk_i4x4 permute
     for(int i = 0; i < N; i++)
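For reference, the host-side loop above rearranges B from a row-per-n [N, K] buffer into a packed [K0, N, K1] buffer, with K1 = KPerBlock and K0 = K / KPerBlock. A minimal standalone sketch of the same index mapping, assuming plain std::vector storage (illustrative helper, not the CK host API):

#include <cassert>
#include <vector>

// Permute B from [N, K] (K contiguous values per n) into a packed
// [K0, N, K1] buffer, mirroring the example's PermuteB == true path.
template <typename T>
std::vector<T> permute_b(const std::vector<T>& b, int N, int K, int KPerBlock)
{
    assert(K % KPerBlock == 0);
    const int K1 = KPerBlock;
    const int K0 = K / KPerBlock;

    std::vector<T> b_permute(b.size());
    for(int j = 0; j < K0; ++j)
        for(int i = 0; i < N; ++i)
            for(int jj = 0; jj < K1; ++jj)
                // destination index [j, i, jj] in the packed [K0, N, K1] buffer
                b_permute[j * N * K1 + i * K1 + jj] = b[i * K + (j * K1 + jj)];
    return b_permute;
}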
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp

@@ -64,7 +64,9 @@ template <typename ALayout,
           BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
           BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
           typename ComputeTypeA                       = CDataType,
-          typename ComputeTypeB                       = ComputeTypeA>
+          typename ComputeTypeB                       = ComputeTypeA,
+          bool PermuteA                               = false,
+          bool PermuteB                               = false>
 struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                                                        BLayout,
                                                        CLayout,

@@ -122,7 +124,9 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                                                 BlkGemmPipeSched,
                                                 BlkGemmPipelineVer,
                                                 ComputeTypeA,
-                                                ComputeTypeB>;
+                                                ComputeTypeB,
+                                                PermuteA,
+                                                PermuteB>;

     using Argument = typename GridwiseGemm::Argument;
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp

@@ -14,8 +14,6 @@
 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

-#define WEIGHT_PERMUTE
-
 namespace ck {

 // Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same
@@ -129,7 +127,9 @@ template <typename ALayout,
           BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
           BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v4,
           typename ComputeTypeA                       = CDataType,
-          typename ComputeTypeB                       = ComputeTypeA>
+          typename ComputeTypeB                       = ComputeTypeA,
+          bool PermuteA                               = false,
+          bool PermuteB                               = false>
 struct GridwiseGemm_xdl_cshuffle_v3
 {
     static constexpr auto I0 = Number<0>{};
@@ -389,35 +389,39 @@ struct GridwiseGemm_xdl_cshuffle_v3
         }
         else
         {
-#ifndef WEIGHT_PERMUTE
-            // not pad N or K
-            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
-                b_grid_desc_nraw_kraw,
-                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
-                           make_pass_through_transform(N)),
-                make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-
-            return b_grid_desc_bk0_n_bk1;
-#else
-            // Weight Tile Permute
-            constexpr index_t BK01 = KPerBlock / BK1Value;
-            const index_t BK0_     = StrideB / BK1Value;
-            const index_t BK00     = BK0_ / BK01;
-
-            const auto b_grid_desc_bk00_n_bk01_bk1_permute =
-                make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value));
-
-            const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor(
-                b_grid_desc_bk00_n_bk01_bk1_permute,
-                make_tuple(make_merge_transform(make_tuple(BK00, BK01)),
-                           make_pass_through_transform(make_tuple(N)),
-                           make_pass_through_transform(BK1Value)),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
-
-            return b_grid_desc_bk0_n_bk1_permute;
-#endif
+            if constexpr(!PermuteB)
+            {
+                // not pad N or K
+                const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
+                    b_grid_desc_nraw_kraw,
+                    make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
+                               make_pass_through_transform(N)),
+                    make_tuple(Sequence<1>{}, Sequence<0>{}),
+                    make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
+                return b_grid_desc_bk0_n_bk1;
+            }
+            else
+            {
+                // Weight Tile Permute
+                constexpr index_t BK01 = KPerBlock / BK1Value;
+                // const index_t BK00 = BK0 / BK01;
+                const index_t BK0_ = StrideB / BK1Value;
+                const index_t BK00 = BK0_ / BK01;
+
+                const auto b_grid_desc_bk00_n_bk01_bk1_permute =
+                    make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value));
+
+                const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor(
+                    b_grid_desc_bk00_n_bk01_bk1_permute,
+                    make_tuple(make_merge_transform(make_tuple(BK00, BK01)),
+                               make_pass_through_transform(make_tuple(N)),
+                               make_pass_through_transform(BK1Value)),
+                    make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
+
+                return b_grid_desc_bk0_n_bk1_permute;
+            }
         }
     }
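The permuted descriptor above views B as a packed [BK00, N, BK01, BK1] buffer (BK01 = KPerBlock / BK1Value) and merges (BK00, BK01) back into the BK0 dimension, so a logical (bk0, n, bk1) access lands where the host-side permute wrote the element. A small sketch of that index arithmetic, assuming packed row-major strides and bk0 = bk00 * BK01 + bk01, and ignoring the pk_i4 byte packing; it is illustrative only, not the descriptor machinery itself:

#include <cstdint>

// Linear offset of logical element (bk0, n, bk1) in the permuted
// [BK00, N, BK01, BK1] packed layout.
inline std::int64_t permuted_b_offset(std::int64_t bk0, std::int64_t n, std::int64_t bk1,
                                      std::int64_t N, std::int64_t BK01, std::int64_t BK1)
{
    const std::int64_t bk00 = bk0 / BK01; // which KPerBlock tile along K
    const std::int64_t bk01 = bk0 % BK01; // position inside the tile
    // packed strides: BK00 -> N*BK01*BK1, N -> BK01*BK1, BK01 -> BK1, BK1 -> 1
    return ((bk00 * N + n) * BK01 + bk01) * BK1 + bk1;
}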
@@ -621,12 +625,15 @@ struct GridwiseGemm_xdl_cshuffle_v3
         }
         else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
        {
-#ifndef WEIGHT_PERMUTE
-            b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;
-#else
-            const int k0_offset = karg.KRead * karg.N;
-            b_k_split_offset    = blockIdx.z * k0_offset / BPackedSize;
-#endif
+            if constexpr(!PermuteB)
+            {
+                b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;
+            }
+            else
+            {
+                const int k0_offset = karg.KRead * karg.N;
+                b_k_split_offset    = blockIdx.z * k0_offset / BPackedSize;
+            }
         }

         if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
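The split-K branch above differs between layouts because of the permute: for plain column-major B each z-block advances KRead elements into a column, while in the permuted [K0, N, KPerBlock] layout one chunk of KRead k-values spans KRead * N contiguous elements. A hedged standalone sketch of that offset rule (illustrative helper, not part of the CK kernel-argument interface):

#include <cstdint>

// Base element offset of the B tensor for split-K block index block_z.
// BPackedSize accounts for packed element types (e.g. two pk_i4 values per byte).
inline std::int64_t b_split_k_offset(std::int64_t block_z,
                                     std::int64_t KRead,
                                     std::int64_t N,
                                     std::int64_t BPackedSize,
                                     bool permute_b)
{
    // permuted layout: a KRead-deep chunk covers all N columns contiguously
    const std::int64_t chunk = permute_b ? KRead * N : KRead;
    return block_z * chunk / BPackedSize;
}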