Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
4ac2606c
Unverified
Commit
4ac2606c
authored
Mar 13, 2023
by
zjing14
Committed by
GitHub
Mar 13, 2023
Browse files
Merge pull request #5 from ROCmSoftwarePlatform/mi300
Add support for gfx940 targets.
parents
52f64967
f8a6c69c
Changes
47
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
107 additions
and
9 deletions
+107
-9
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp
+2
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp
+2
-1
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
+59
-1
include/ck/utility/amd_xdlops.hpp
include/ck/utility/amd_xdlops.hpp
+39
-1
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+2
-0
script/cmake-ck-dev.sh
script/cmake-ck-dev.sh
+2
-3
script/cmake-ck-release.sh
script/cmake-ck-release.sh
+1
-2
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp
View file @
4ac2606c
...
...
@@ -49,7 +49,8 @@ __global__ void
const
CElementwiseOperation
c_element_op
,
const
Block2CTileMap
block_2_ctile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp
View file @
4ac2606c
...
...
@@ -53,7 +53,8 @@ __global__ void
const
CElementwiseOperation
c_element_op
,
const
Block2CTileMap
block_2_ctile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
...
...
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
View file @
4ac2606c
...
...
@@ -27,6 +27,8 @@ enum struct MfmaInstr
mfma_f32_16x16x8bf16
,
mfma_i32_32x32x8i8
,
mfma_i32_16x16x16i8
,
mfma_i32_32x32x16i8
,
mfma_i32_16x16x32i8
,
mfma_f64_16x16x4f64
};
...
...
@@ -386,6 +388,50 @@ struct mfma_type<MfmaInstr::mfma_i32_16x16x16i8>
}
};
template
<
>
struct
mfma_type
<
MfmaInstr
::
mfma_i32_32x32x16i8
>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
4
;
static
constexpr
index_t
num_regs_per_blk
=
16
;
static
constexpr
index_t
num_threads_per_blk
=
32
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
2
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
m_per_blk
=
32
;
static
constexpr
index_t
n_per_blk
=
32
;
static
constexpr
index_t
k_per_blk
=
8
;
static
constexpr
bool
is_k_reduction
=
true
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
{
intrin_mfma_i32_32x32x16i8
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
reg_c
);
}
};
template
<
>
struct
mfma_type
<
MfmaInstr
::
mfma_i32_16x16x32i8
>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
1
;
static
constexpr
index_t
num_regs_per_blk
=
4
;
static
constexpr
index_t
num_threads_per_blk
=
16
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
4
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
m_per_blk
=
16
;
static
constexpr
index_t
n_per_blk
=
16
;
static
constexpr
index_t
k_per_blk
=
8
;
static
constexpr
bool
is_k_reduction
=
true
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
{
intrin_mfma_i32_16x16x32i8
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
reg_c
);
}
};
template
<
>
struct
mfma_type
<
MfmaInstr
::
mfma_f64_16x16x4f64
>
{
...
...
@@ -524,17 +570,29 @@ struct MfmaSelector
#endif
}
#if defined(CK_USE_AMD_MFMA_GFX940)
template
<
>
static
constexpr
auto
GetMfma
<
int8_t
,
32
,
32
>
()
{
return
MfmaInstr
::
mfma_i32_32x32x16i8
;
}
template
<
>
static
constexpr
auto
GetMfma
<
int8_t
,
16
,
16
>
()
{
return
MfmaInstr
::
mfma_i32_16x16x32i8
;
}
#else
template
<
>
static
constexpr
auto
GetMfma
<
int8_t
,
32
,
32
>
()
{
return
MfmaInstr
::
mfma_i32_32x32x8i8
;
}
template
<
>
static
constexpr
auto
GetMfma
<
int8_t
,
16
,
16
>
()
{
return
MfmaInstr
::
mfma_i32_16x16x16i8
;
}
#endif
static
constexpr
auto
selected_mfma
=
mfma_type
<
GetMfma
<
base_type
,
MPerXdlops
,
NPerXdlops
>
()
>
{};
...
...
include/ck/utility/amd_xdlops.hpp
View file @
4ac2606c
...
...
@@ -297,6 +297,44 @@ struct intrin_mfma_i32_16x16x16i8<16, 16>
}
};
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_i32_32x32x16i8
;
template
<
>
struct
intrin_mfma_i32_32x32x16i8
<
32
,
32
>
{
template
<
class
FloatC
>
__device__
static
void
Run
(
const
int8x8_t
&
reg_a
,
const
int8x8_t
&
reg_b
,
FloatC
&
reg_c
)
{
reg_c
.
template
AsType
<
int32x16_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_mfma_i32_32x32x16_i8
(
bit_cast
<
int64_t
>
(
reg_a
),
bit_cast
<
int64_t
>
(
reg_b
),
reg_c
.
template
AsType
<
int32x16_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
}
};
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_i32_16x16x32i8
;
template
<
>
struct
intrin_mfma_i32_16x16x32i8
<
16
,
16
>
{
template
<
class
FloatC
>
__device__
static
void
Run
(
const
int8x8_t
&
reg_a
,
const
int8x8_t
&
reg_b
,
FloatC
&
reg_c
)
{
reg_c
.
template
AsType
<
int32x4_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_mfma_i32_16x16x32i8
(
bit_cast
<
int64_t
>
(
reg_a
),
bit_cast
<
int64_t
>
(
reg_b
),
reg_c
.
template
AsType
<
int32x4_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
}
};
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_f64_16x16x4f64
;
...
...
@@ -306,7 +344,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
double
&
reg_a
,
const
double
&
reg_b
,
FloatC
&
reg_c
)
{
#ifdef
__gfx90
a
__
#if
def
ined(__gfx90a__) || defined(
__gfx9
4
0__
)
reg_c
.
template
AsType
<
double4_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_mfma_f64_16x16x4f64
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
double4_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
#else
...
...
include/ck/utility/data_type.hpp
View file @
4ac2606c
...
...
@@ -898,6 +898,8 @@ struct vector_type<T, 256>
}
};
using
int64_t
=
long
;
// fp64
using
double2_t
=
typename
vector_type
<
double
,
2
>::
type
;
using
double4_t
=
typename
vector_type
<
double
,
4
>::
type
;
...
...
script/cmake-ck-dev.sh
View file @
4ac2606c
...
...
@@ -10,10 +10,9 @@ cmake
-D
CMAKE_CXX_COMPILER
=
/opt/rocm/bin/hipcc
\
-D
CMAKE_CXX_FLAGS
=
"-O3 -ftemplate-backtrace-limit=0 -gline-tables-only -save-temps=
$PWD
"
\
-D
CMAKE_BUILD_TYPE
=
Release
\
-D
BUILD_DEV
=
O
FF
\
-D
GPU_TARGETS
=
"gfx90
a"
\
-D
BUILD_DEV
=
O
N
\
-D
GPU_TARGETS
=
"gfx90
8;gfx90a;gfx940"
\
-D
CMAKE_VERBOSE_MAKEFILE:BOOL
=
ON
\
-D
USE_BITINT_EXTENSION_INT4
=
OFF
\
${
MY_PROJECT_SOURCE
}
#-D AMDGPU_TARGETS=gfx90a;gfx908
script/cmake-ck-release.sh
View file @
4ac2606c
...
...
@@ -11,9 +11,8 @@ cmake
-D
CMAKE_CXX_FLAGS
=
"-O3"
\
-D
CMAKE_BUILD_TYPE
=
Release
\
-D
BUILD_DEV
=
OFF
\
-D
GPU_TARGETS
=
"gfx908;gfx90a
"
\
-D
GPU_TARGETS
=
"gfx908;gfx90a
;gfx940"
\
-D
CMAKE_VERBOSE_MAKEFILE:BOOL
=
ON
\
-D
USE_BITINT_EXTENSION_INT4
=
OFF
\
${
MY_PROJECT_SOURCE
}
#-D AMDGPU_TARGETS=gfx90a;gfx908
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment