Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
9de3a085
Commit
9de3a085
authored
Oct 28, 2024
by
Jing Zhang
Browse files
format
parent
a6ccd2ec
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
33 additions
and
32 deletions
+33
-32
example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp
example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp
+7
-7
example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp
example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp
+7
-7
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+12
-12
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
+5
-5
library/include/ck/library/tensor_operation_instance/gpu/gemm_universal.hpp
.../library/tensor_operation_instance/gpu/gemm_universal.hpp
+2
-1
No files found.
example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp
View file @
9de3a085
...
@@ -69,7 +69,7 @@ using DeviceGemmV2Instance =
...
@@ -69,7 +69,7 @@ using DeviceGemmV2Instance =
// clang-format on
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
BDataType
,
CDataType
,
CDataType
,
AccDataType
,
AccDataType
,
...
...
example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp
View file @
9de3a085
...
@@ -58,7 +58,7 @@ using DeviceGemmV2Instance =
...
@@ -58,7 +58,7 @@ using DeviceGemmV2Instance =
// clang-format on
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
BDataType
,
CDataType
,
CDataType
,
AccDataType
,
AccDataType
,
...
...
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
9de3a085
...
@@ -11,7 +11,7 @@
...
@@ -11,7 +11,7 @@
namespace
ck
{
namespace
ck
{
//https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
//
https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
__host__
__device__
inline
half4_t
pki4_to_half4
(
int
q
)
__host__
__device__
inline
half4_t
pki4_to_half4
(
int
q
)
{
{
const
int
LO
=
0x000f000f
;
const
int
LO
=
0x000f000f
;
...
@@ -54,7 +54,7 @@ __host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
...
@@ -54,7 +54,7 @@ __host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
__host__
__device__
inline
bhalf4_t
pki4_to_bhalf4
(
int
q
)
__host__
__device__
inline
bhalf4_t
pki4_to_bhalf4
(
int
q
)
{
{
uint32_t
i8s
=
(
q
&
0xf
)
|
((
q
&
0xf0
)
<<
4
)
|
((
q
&
0xf00
)
<<
8
)
|
((
q
&
0xf000
)
<<
12
);
uint32_t
i8s
=
(
q
&
0xf
)
|
((
q
&
0xf0
)
<<
4
)
|
((
q
&
0xf00
)
<<
8
)
|
((
q
&
0xf000
)
<<
12
);
//uint32_t i8s = q & 0xf0f0f0f;
//
uint32_t i8s = q & 0xf0f0f0f;
static
constexpr
uint32_t
fp32_base
=
0x4B000000
;
static
constexpr
uint32_t
fp32_base
=
0x4B000000
;
...
@@ -73,8 +73,10 @@ __host__ __device__ inline bhalf4_t pki4_to_bhalf4(int q)
...
@@ -73,8 +73,10 @@ __host__ __device__ inline bhalf4_t pki4_to_bhalf4(int q)
fp32_intermediates
[
3
]
-=
8388616.
f
;
fp32_intermediates
[
3
]
-=
8388616.
f
;
vector_type
<
bhalf_t
,
4
>
res
;
vector_type
<
bhalf_t
,
4
>
res
;
res
.
template
AsType
<
bhalf2_t
>()(
Number
<
1
>
{})
=
bit_cast
<
bhalf2_t
>
(
__byte_perm
(
fp32_intermediates_casted
[
0
],
fp32_intermediates_casted
[
1
],
0x7632
));
res
.
template
AsType
<
bhalf2_t
>()(
Number
<
1
>
{})
=
bit_cast
<
bhalf2_t
>
(
res
.
template
AsType
<
bhalf2_t
>()(
Number
<
0
>
{})
=
bit_cast
<
bhalf2_t
>
(
__byte_perm
(
fp32_intermediates_casted
[
2
],
fp32_intermediates_casted
[
3
],
0x7632
));
__byte_perm
(
fp32_intermediates_casted
[
0
],
fp32_intermediates_casted
[
1
],
0x7632
));
res
.
template
AsType
<
bhalf2_t
>()(
Number
<
0
>
{})
=
bit_cast
<
bhalf2_t
>
(
__byte_perm
(
fp32_intermediates_casted
[
2
],
fp32_intermediates_casted
[
3
],
0x7632
));
return
res
.
template
AsType
<
bhalf4_t
>()[
Number
<
0
>
{}];
return
res
.
template
AsType
<
bhalf4_t
>()[
Number
<
0
>
{}];
}
}
...
@@ -94,7 +96,6 @@ __host__ __device__ inline bhalf2_t pki4_to_bhalf2(pk_i4_t q)
...
@@ -94,7 +96,6 @@ __host__ __device__ inline bhalf2_t pki4_to_bhalf2(pk_i4_t q)
return
res
.
template
AsType
<
bhalf2_t
>()[
Number
<
0
>
{}];
return
res
.
template
AsType
<
bhalf2_t
>()[
Number
<
0
>
{}];
}
}
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
element_wise
{
namespace
element_wise
{
...
@@ -137,7 +138,6 @@ struct PassThroughPack8
...
@@ -137,7 +138,6 @@ struct PassThroughPack8
result
.
template
AsType
<
bhalf4_t
>()(
Number
<
0
>
{})
=
pki4_to_bhalf4
(
bit_cast
<
int
>
(
x
)
>>
16
);
result
.
template
AsType
<
bhalf4_t
>()(
Number
<
0
>
{})
=
pki4_to_bhalf4
(
bit_cast
<
int
>
(
x
)
>>
16
);
result
.
template
AsType
<
bhalf4_t
>()(
Number
<
1
>
{})
=
pki4_to_bhalf4
(
bit_cast
<
int
>
(
x
));
result
.
template
AsType
<
bhalf4_t
>()(
Number
<
1
>
{})
=
pki4_to_bhalf4
(
bit_cast
<
int
>
(
x
));
y
=
result
.
template
AsType
<
bhalf8_t
>()[
Number
<
0
>
{}];
y
=
result
.
template
AsType
<
bhalf8_t
>()[
Number
<
0
>
{}];
#else
#else
vector_type
<
bhalf_t
,
8
>
dst
;
vector_type
<
bhalf_t
,
8
>
dst
;
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
View file @
9de3a085
library/include/ck/library/tensor_operation_instance/gpu/gemm_universal.hpp
View file @
9de3a085
...
@@ -838,7 +838,8 @@ struct DeviceOperationInstanceFactory<
...
@@ -838,7 +838,8 @@ struct DeviceOperationInstanceFactory<
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
is_same_v
<
CLayout
,
Row
>
)
{
{
add_device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_v2_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_v2_default_instances
(
op_ptrs
);
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment