Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
e053e947
Commit
e053e947
authored
Oct 21, 2024
by
Jing Zhang
Browse files
weight permute
parent
82bb8dde
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
88 additions
and
9 deletions
+88
-9
example/01_gemm/run_gemm_example_v2.inc
example/01_gemm/run_gemm_example_v2.inc
+74
-1
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+6
-5
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
+6
-1
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
...library/reference_tensor_operation/cpu/reference_gemm.hpp
+2
-2
No files found.
example/01_gemm/run_gemm_example_v2.inc
View file @
e053e947
...
...
@@ -134,6 +134,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
Tensor
<
ADataType
>
a_m_k
(
f_host_tensor_descriptor
(
M
,
K
,
StrideA
,
ALayout
{}));
Tensor
<
BDataType
>
b_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
BLayout
{}));
Tensor
<
BDataType
>
b_k_n_permute
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
BLayout
{}));
switch
(
config
.
init_method
)
{
...
...
@@ -169,8 +170,80 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
DeviceMem
b_k_n_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
for
(
int
i
=
0
;
i
<
N
;
i
++
)
{
for
(
int
j
=
0
;
j
<
K
;
j
+=
8
)
{
int
input
[
8
];
for
(
int
k
=
0
;
k
<
4
;
k
++
)
{
int
i4x2
=
b_k_n
(
j
+
k
*
2
,
i
);
input
[
k
*
2
+
0
]
=
(
i4x2
>>
4
)
&
0xf
;
input
[
k
*
2
+
1
]
=
(
i4x2
>>
0
)
&
0xf
;
}
// for(int k = 1; k <= 4; k++)
{
int
hi
=
input
[
2
];
int
lo
=
input
[
0
];
int
i4x2
=
(
hi
<<
4
)
|
lo
;
b_k_n_permute
(
j
+
0
,
i
)
=
i4x2
;
}
{
int
hi
=
input
[
6
];
int
lo
=
input
[
4
];
int
i4x2
=
(
hi
<<
4
)
|
lo
;
b_k_n_permute
(
j
+
2
,
i
)
=
i4x2
;
}
{
int
hi
=
input
[
3
];
int
lo
=
input
[
1
];
int
i4x2
=
(
hi
<<
4
)
|
lo
;
b_k_n_permute
(
j
+
4
,
i
)
=
i4x2
;
}
{
int
hi
=
input
[
7
];
int
lo
=
input
[
5
];
int
i4x2
=
(
hi
<<
4
)
|
lo
;
b_k_n_permute
(
j
+
6
,
i
)
=
i4x2
;
}
}
}
#if 0
ck
::
pk_i4_t
i4s
[
4
];
i4s
[
0
]
=
0xa8
;
i4s
[
1
]
=
0xec
;
i4s
[
2
]
=
0xb9
;
i4s
[
3
]
=
0xfd
;
ck
::
vector_type
<
ck
::
half_t
,
8
>
result
;
result
.
template
AsType
<
ck
::
half4_t
>
()(
ck
::
Number
<
0
>
{})
=
ck
::
pki4_to_half4
(
ck
::
bit_cast
<
int
>
(
i4s
));
result
.
template
AsType
<
ck
::
half4_t
>
()(
ck
::
Number
<
1
>
{})
=
ck
::
pki4_to_half4
(
ck
::
bit_cast
<
int
>
(
i4s
)
>>
8
);
printf
(
"%f %f %f %f %f %f %f %f
\n
"
,
ck
::
type_convert
<
float
>
(
result
.
template
AsType
<
ck
::
half_t
>
()[
ck
::
Number
<
0
>
{}]),
ck
::
type_convert
<
float
>
(
result
.
template
AsType
<
ck
::
half_t
>
()[
ck
::
Number
<
1
>
{}]),
ck
::
type_convert
<
float
>
(
result
.
template
AsType
<
ck
::
half_t
>
()[
ck
::
Number
<
2
>
{}]),
ck
::
type_convert
<
float
>
(
result
.
template
AsType
<
ck
::
half_t
>
()[
ck
::
Number
<
3
>
{}]),
ck
::
type_convert
<
float
>
(
result
.
template
AsType
<
ck
::
half_t
>
()[
ck
::
Number
<
4
>
{}]),
ck
::
type_convert
<
float
>
(
result
.
template
AsType
<
ck
::
half_t
>
()[
ck
::
Number
<
5
>
{}]),
ck
::
type_convert
<
float
>
(
result
.
template
AsType
<
ck
::
half_t
>
()[
ck
::
Number
<
6
>
{}]),
ck
::
type_convert
<
float
>
(
result
.
template
AsType
<
ck
::
half_t
>
()[
ck
::
Number
<
7
>
{}])
);
#endif
a_m_k_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_k_n_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
b_k_n_device_buf
.
ToDevice
(
b_k_n
_permute
.
mData
.
data
());
DeviceMem
workspace
;
auto
a_element_op
=
AElementOp
{};
...
...
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
e053e947
...
...
@@ -10,10 +10,8 @@
#include "ck/utility/amd_inline_asm.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
element_wise
{
__device__
inline
half4_t
pki4_to_half4
(
int
q
)
__host__
__device__
inline
half4_t
pki4_to_half4
(
int
q
)
{
const
int
LO
=
0x000f000f
;
const
int
HI
=
0x00f000f0
;
...
...
@@ -40,7 +38,7 @@ __device__ inline half4_t pki4_to_half4(int q)
return
res
.
template
AsType
<
half4_t
>()[
Number
<
0
>
{}];
}
__device__
inline
half2_t
pki4_to_half2
(
pk_i4_t
q
)
__host__
__device__
inline
half2_t
pki4_to_half2
(
pk_i4_t
q
)
{
#if 0
uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
...
...
@@ -58,7 +56,7 @@ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
int
x_l
=
(
x_u8
&
0x0f
);
int
x_h
=
(
x_u8
&
0xf0
)
<<
12
;
const
int
EX
=
0x64006400
;
const
int
EX
=
0x64006400
;
const
int
SUB
=
0xE408E408
;
//-8
int
lo
=
(
x_l
|
x_h
)
|
EX
;
...
...
@@ -67,6 +65,9 @@ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
#endif
}
namespace
tensor_operation
{
namespace
element_wise
{
struct
PassThroughPack8
{
template
<
typename
Y
,
typename
X
>
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
View file @
e053e947
...
...
@@ -396,8 +396,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
#else
const
index_t
N0
=
N
/
NPerBlock
;
const
index_t
N1
=
NPerBlock
;
const
index_t
N0
=
N
/
N1
;
const
auto
b_grid_desc_n0_bk0_n1_bk1
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N0
,
BK0
,
N1
,
BK1Value
));
...
...
@@ -614,7 +614,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
}
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>
)
{
#if 1
b_k_split_offset
=
blockIdx
.
z
*
karg
.
KRead
/
BPackedSize
;
#else
const
int
k0_offset
=
karg
.
KRead
*
NPerBlock
;
b_k_split_offset
=
blockIdx
.
z
*
k0_offset
/
BPackedSize
;
#endif
}
if
(
blockIdx
.
z
<
static_cast
<
uint32_t
>
(
karg
.
KBatch
-
1
))
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
View file @
e053e947
...
...
@@ -78,7 +78,7 @@ struct ReferenceGemm : public device::BaseOperator
{
pk_i4_t
i4x2
=
arg
.
a_m_k_
(
m
,
k
);
int8_t
i4
=
0
;
if
(
k
%
2
==
0
)
if
(
k
%
2
==
1
)
i4
=
(
i4x2
>>
0
)
&
0xf
;
else
i4
=
(
i4x2
>>
4
)
&
0xf
;
...
...
@@ -99,7 +99,7 @@ struct ReferenceGemm : public device::BaseOperator
{
pk_i4_t
i4x2
=
arg
.
b_k_n_
(
k
,
n
);
int8_t
i4
=
0
;
if
(
k
%
2
==
0
)
if
(
k
%
2
==
1
)
i4
=
(
i4x2
>>
0
)
&
0xf
;
else
i4
=
(
i4x2
>>
4
)
&
0xf
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment