Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
941d1f7c
"official/nlp/tools/tokenization.py" did not exist on "955389a9f2acdf4f48700bfc22b6350c360ad3b9"
Unverified
Commit
941d1f7c
authored
Jun 27, 2024
by
Illia Silin
Committed by
GitHub
Jun 27, 2024
Browse files
Merging the gfx12 code into public repo. (#1362)
parent
a32b1bc6
Changes
49
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
129 additions
and
11 deletions
+129
-11
include/ck/utility/amd_wmma.hpp
include/ck/utility/amd_wmma.hpp
+82
-0
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+1
-1
include/ck/utility/synchronization.hpp
include/ck/utility/synchronization.hpp
+17
-0
include/ck_tile/core/config.hpp
include/ck_tile/core/config.hpp
+4
-1
library/src/tensor_operation_instance/gpu/CMakeLists.txt
library/src/tensor_operation_instance/gpu/CMakeLists.txt
+4
-4
profiler/src/CMakeLists.txt
profiler/src/CMakeLists.txt
+2
-2
test/CMakeLists.txt
test/CMakeLists.txt
+2
-2
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
...uped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
+1
-1
test/wmma_op/wmma_op_util.hpp
test/wmma_op/wmma_op_util.hpp
+16
-0
No files found.
include/ck/utility/amd_wmma.hpp
View file @
941d1f7c
...
...
@@ -257,5 +257,87 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
}
};
// gfx12
/********************************WAVE32 MODE***********************************************/
#if defined(__gfx1200__) || defined(__gfx1201__)
#define __gfx12__
#endif
// src: fp16, dst: fp32
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_wmma_f32_16x16x16_f16_w32_gfx12
;
template
<
>
struct
intrin_wmma_f32_16x16x16_f16_w32_gfx12
<
16
,
16
>
{
template
<
class
FloatC
>
__device__
static
void
Run
(
const
half8_t
&
reg_a
,
const
half8_t
&
reg_b
,
FloatC
&
reg_c
)
{
// * Inline assembly need to elimate the duplicated data load, compiler won't help you
// delete them.
// amd_assembly_wmma_f32_16x16x16_f16_w32(
// reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
#if defined(__gfx12__)
reg_c
.
template
AsType
<
float8_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float8_t
>()[
Number
<
0
>
{}]);
#else
ignore
=
reg_a
;
ignore
=
reg_b
;
ignore
=
reg_c
;
#endif
}
};
// src: bf16, dst: fp32
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_wmma_f32_16x16x16_bf16_w32_gfx12
;
template
<
>
struct
intrin_wmma_f32_16x16x16_bf16_w32_gfx12
<
16
,
16
>
{
template
<
class
FloatC
>
__device__
static
void
Run
(
const
bhalf8_t
&
reg_a
,
const
bhalf8_t
&
reg_b
,
FloatC
&
reg_c
)
{
#if defined(__gfx12__)
reg_c
.
template
AsType
<
float8_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float8_t
>()[
Number
<
0
>
{}]);
#else
ignore
=
reg_a
;
ignore
=
reg_b
;
ignore
=
reg_c
;
#endif
}
};
// src: iu8, dst: i32
template
<
index_t
MPerWave
,
index_t
NPerWave
,
bool
neg_a
,
bool
neg_b
,
bool
clamp
>
struct
intrin_wmma_i32_16x16x16_iu8_w32_gfx12
;
template
<
bool
neg_a
,
bool
neg_b
,
bool
clamp
>
struct
intrin_wmma_i32_16x16x16_iu8_w32_gfx12
<
16
,
16
,
neg_a
,
neg_b
,
clamp
>
{
template
<
class
FloatC
>
__device__
static
void
Run
(
const
int8x8_t
&
reg_a
,
const
int8x8_t
&
reg_b
,
FloatC
&
reg_c
)
{
#if defined(__gfx12__)
reg_c
.
template
AsType
<
int32x8_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12
(
neg_a
,
bit_cast
<
int32x2_t
>
(
reg_a
),
neg_b
,
bit_cast
<
int32x2_t
>
(
reg_b
),
reg_c
.
template
AsType
<
int32x8_t
>()[
Number
<
0
>
{}],
clamp
);
#else
ignore
=
reg_a
;
ignore
=
reg_b
;
ignore
=
reg_c
;
#endif
}
};
}
// namespace ck
#endif
include/ck/utility/data_type.hpp
View file @
941d1f7c
...
...
@@ -203,7 +203,7 @@ struct vector_type<T, 1>
}
};
int
static
err
=
0
;
__device__
int
static
err
=
0
;
template
<
typename
T
>
struct
vector_type
<
T
,
2
>
{
...
...
include/ck/utility/synchronization.hpp
View file @
941d1f7c
...
...
@@ -10,12 +10,20 @@ namespace ck {
__device__
void
block_sync_lds
()
{
#if CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
#ifdef __gfx12__
asm
volatile
(
"\
s_wait_dscnt 0x0
\n
\
s_barrier_signal -1
\n
\
s_barrier_wait -1 \
"
::
);
#else
// asm volatile("\
// s_waitcnt lgkmcnt(0) \n \
// s_barrier \
// " ::);
__builtin_amdgcn_s_waitcnt
(
0xc07f
);
__builtin_amdgcn_s_barrier
();
#endif
#else
__syncthreads
();
#endif
...
...
@@ -23,11 +31,20 @@ __device__ void block_sync_lds()
__device__
void
block_sync_lds_direct_load
()
{
#ifdef __gfx12__
asm
volatile
(
"\
s_wait_vmcnt 0x0
\n
\
s_wait_dscnt 0x0
\n
\
s_barrier_signal -1
\n
\
s_barrier_wait -1 \
"
::
);
#else
asm
volatile
(
"\
s_waitcnt vmcnt(0)
\n
\
s_waitcnt lgkmcnt(0)
\n
\
s_barrier \
"
::
);
#endif
}
__device__
void
s_nop
()
...
...
include/ck_tile/core/config.hpp
View file @
941d1f7c
...
...
@@ -17,6 +17,9 @@
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
#define __gfx11__
#endif
#if defined(__gfx1200__) || defined(__gfx1201__)
#define __gfx12__
#endif
#ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS
#include "hip/hip_runtime.h"
...
...
@@ -155,7 +158,7 @@
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#elif defined(__gfx103__) // for GPU code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#elif defined(__gfx11__) // for GPU code
#elif defined(__gfx11__)
|| defined(__gfx12__)
// for GPU code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#endif
...
...
library/src/tensor_operation_instance/gpu/CMakeLists.txt
View file @
941d1f7c
...
...
@@ -59,7 +59,7 @@ function(add_instance_library INSTANCE_NAME)
endforeach
()
# Do not build WMMA instances if gfx11 targets are not on the target list
foreach
(
source IN LISTS ARGN
)
if
(
NOT INST
_TARGETS MATCHES
"gfx1
1
"
AND source MATCHES
"_wmma"
)
if
(
NOT GPU_TARGETS MATCHES
"gfx11"
AND NOT GPU
_TARGETS MATCHES
"gfx1
2
"
AND source MATCHES
"_wmma"
)
message
(
"removing wmma instance
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
...
...
@@ -177,7 +177,7 @@ FOREACH(subdir_path ${dir_list})
message
(
"Found only xdl instances, but gfx9 is not on the targets list. Skipping."
)
set
(
add_inst 0
)
endif
()
if
((
"
${
cmake_instance
}
"
MATCHES
"ONLY WMMA_KERNELS"
)
AND
(
NOT
INST
_TARGETS MATCHES
"gfx1
1
"
))
if
((
"
${
cmake_instance
}
"
MATCHES
"ONLY WMMA_KERNELS"
)
AND
(
NOT
GPU_TARGETS MATCHES
"gfx11"
)
AND
(
NOT GPU
_TARGETS MATCHES
"gfx1
2
"
))
message
(
"Found only wmma instances, but gfx11 is not on the targets list. Skipping."
)
set
(
add_inst 0
)
endif
()
...
...
@@ -185,11 +185,11 @@ FOREACH(subdir_path ${dir_list})
message
(
"Found only xdl and dl instances, but gfx9 is not on the targets listand DL_KERNELS is not set. Skipping."
)
set
(
add_inst 0
)
endif
()
if
((
"
${
cmake_instance
}
"
MATCHES
"ONLY XDL_AND_WMMA_KERNELS"
)
AND
(
NOT
INST
_TARGETS MATCHES
"gfx11"
)
AND
(
NOT
INST
_TARGETS MATCHES
"gfx9"
))
if
((
"
${
cmake_instance
}
"
MATCHES
"ONLY XDL_AND_WMMA_KERNELS"
)
AND
(
NOT
GPU
_TARGETS MATCHES
"gfx11"
)
AND
(
NOT
GPU_TARGETS MATCHES
"gfx12"
)
AND
(
NOT GPU
_TARGETS MATCHES
"gfx9"
))
message
(
"Found only xdl and wmma instances, but gfx11 and gfx9 are not on the targets list. Skipping."
)
set
(
add_inst 0
)
endif
()
if
((
"
${
cmake_instance
}
"
MATCHES
"XDL_DL_WMMA_KERNELS"
)
AND
(
NOT
INST
_TARGETS MATCHES
"gfx11"
)
AND
(
NOT
INST
_TARGETS MATCHES
"gfx9"
)
AND
(
NOT DEFINED DL_KERNELS
))
if
((
"
${
cmake_instance
}
"
MATCHES
"XDL_DL_WMMA_KERNELS"
)
AND
(
NOT
GPU
_TARGETS MATCHES
"gfx11"
)
AND
(
NOT
GPU_TARGETS MATCHES
"gfx12"
)
AND
(
NOT GPU
_TARGETS MATCHES
"gfx9"
)
AND
(
NOT DEFINED DL_KERNELS
))
message
(
"Found xdl, dl, and wmma instances, but none of those meet the target list. Skipping."
)
set
(
add_inst 0
)
endif
()
...
...
profiler/src/CMakeLists.txt
View file @
941d1f7c
...
...
@@ -59,7 +59,7 @@ if(GPU_TARGETS MATCHES "gfx9")
endif
()
if
(
GPU_TARGETS MATCHES
"gfx11"
OR GPU_TARGETS MATCHES
"gfx9"
)
if
(
GPU_TARGETS MATCHES
"gfx11"
OR GPU_TARGETS MATCHES
"gfx12"
OR GPU_TARGETS MATCHES
"gfx9"
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
list
(
APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp
)
endif
()
...
...
@@ -134,7 +134,7 @@ if(GPU_TARGETS MATCHES "gfx9")
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv2d_bwd_weight_instance
)
endif
()
if
(
GPU_TARGETS MATCHES
"gfx9"
OR GPU_TARGETS MATCHES
"gfx11"
)
if
(
GPU_TARGETS MATCHES
"gfx9"
OR GPU_TARGETS MATCHES
"gfx11"
OR GPU_TARGETS MATCHES
"gfx12"
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_bilinear_instance
)
endif
()
...
...
test/CMakeLists.txt
View file @
941d1f7c
...
...
@@ -60,7 +60,7 @@ function(add_test_executable TEST_NAME)
endif
()
endforeach
()
foreach
(
source IN LISTS ARGN
)
if
(
NOT TEST
_TARGETS MATCHES
"gfx1
1
"
AND source MATCHES
"wmma"
)
if
(
NOT GPU_TARGETS MATCHES
"gfx11"
AND NOT GPU
_TARGETS MATCHES
"gfx1
2
"
AND source MATCHES
"wmma"
)
message
(
"removing wmma test
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
...
...
@@ -139,7 +139,7 @@ function(add_gtest_executable TEST_NAME)
endif
()
endforeach
()
foreach
(
source IN LISTS ARGN
)
if
(
NOT TEST
_TARGETS MATCHES
"gfx1
1
"
AND source MATCHES
"wmma"
)
if
(
NOT GPU_TARGETS MATCHES
"gfx11"
AND NOT GPU
_TARGETS MATCHES
"gfx1
2
"
AND source MATCHES
"wmma"
)
message
(
"removing wmma test
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
...
...
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
View file @
941d1f7c
...
...
@@ -44,7 +44,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
}
}
if
(
ck
::
is_gfx11_supported
())
if
(
ck
::
is_gfx11_supported
()
||
ck
::
is_gfx12_supported
()
)
{
// on gfx11 only support for 3d is implemented
if
constexpr
(
NDimSpatial
{}
!=
3
)
...
...
test/wmma_op/wmma_op_util.hpp
View file @
941d1f7c
...
...
@@ -140,10 +140,18 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
p_shared
[
8
*
16
*
lane_hi
+
8
*
lane_lo
+
ele
+
16
*
16
]
=
b_temp
[
ele
];
}
#ifdef __gfx12__
asm
volatile
(
"\
s_wait_dscnt 0x0
\n
\
s_barrier_signal -1
\n
\
s_barrier_wait -1 \
"
::
);
#else
asm
volatile
(
"\
s_waitcnt lgkmcnt(0)
\n
\
s_barrier \
"
::
);
#endif
for
(
int
ele
=
0
;
ele
<
16
;
++
ele
)
{
...
...
@@ -155,10 +163,18 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
a_frag
[
ele
]
=
p_shared
[(
ele
/
8
)
*
16
*
8
+
8
*
lane
+
ele
%
8
];
}
#ifdef __gfx12__
asm
volatile
(
"\
s_wait_dscnt 0x0
\n
\
s_barrier_signal -1
\n
\
s_barrier_wait -1 \
"
::
);
#else
asm
volatile
(
"\
s_waitcnt lgkmcnt(0)
\n
\
s_barrier \
"
::
);
#endif
// sync threads, similar to mma_sync
// __syncthreads();
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment