gaoqiong / composable_kernel · Commits

Commit 36c38ad9, authored Oct 21, 2022 by aska-0096

wmma_op + unit test

parent 685860c2

Showing 6 changed files with 330 additions and 1 deletion (+330 -1)
include/ck/ck.hpp                  +10  -1
include/ck/utility/amd_wmma.hpp    +136 -0
include/ck/utility/data_type.hpp   +5   -0
test/CMakeLists.txt                +1   -0
test/wmma_op/CMakeLists.txt        +2   -0
test/wmma_op/wmma_op.cpp           +176 -0
include/ck/ck.hpp

@@ -25,7 +25,7 @@
 // check GPU target
 #ifdef __HIP_DEVICE_COMPILE__
 #if !(defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
-      defined(__gfx90a__) || defined(__gfx1030__))
+      defined(__gfx90a__) || defined(__gfx1030__) || defined(__gfx1100__))
 #error Not supported target
 #endif
 #endif

@@ -38,6 +38,8 @@
 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
 #elif defined(__gfx1030__) // for GPU code
 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
+#elif defined(__gfx1100__) // for GPU code
+#define CK_BUFFER_RESOURCE_3RD_DWORD 0x10020000
 #endif

 // FMA instruction

@@ -62,6 +64,13 @@
 #define CK_USE_AMD_MFMA_BF16_1K_OP
 #endif

+// WMMA instruction
+#ifndef __HIP_DEVICE_COMPILE__ // for host code
+#define CK_USE_AMD_WMMA
+#elif defined(__gfx1100__) // for GPU code
+#define CK_USE_AMD_WMMA
+#endif
+
 // buffer load
 #define CK_USE_AMD_BUFFER_LOAD 1
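The new CK_USE_AMD_WMMA switch follows the same host/device split as the MFMA switches just above it: host passes always see it, while device passes define it only when targeting gfx1100. A minimal sketch of the intended gating pattern (an assumption for illustration; the consuming sites are not part of this diff):

#include "ck/ck.hpp"

#ifdef CK_USE_AMD_WMMA
// Only reached on host passes or when the device pass targets gfx1100,
// so the gfx11-only WMMA builtins declared below are safe to reference.
#include "ck/utility/amd_wmma.hpp"
#endif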
include/ck/utility/amd_wmma.hpp  0 → 100644

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#ifndef CK_AMD_WMMA_HPP
#define CK_AMD_WMMA_HPP

#include "data_type.hpp"

namespace ck {

// wave32 only
// src: fp16, dst: fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_f16_w32;

template <>
struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
            reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
    }
};

// src: bf16, dst: fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_bf16_w32;

template <>
struct intrin_wmma_f32_16x16x16_bf16_w32<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float8_t>()(Number<0>{}) =
            __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(
                reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
    }
};

// src: fp16, dst: fp16
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f16_16x16x16_f16_w32;

template <>
struct intrin_wmma_f16_16x16x16_f16_w32<16, 16>
{
    template <class FloatC>
    __device__ static void
    Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c, const bool opsel)
    {
        // opsel usage
        // false: D0.[0:15]  = result
        // true : D0.[16:31] = result
        reg_c.template AsType<half16_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(
            reg_a, reg_b, reg_c.template AsType<half16_t>()[Number<0>{}], opsel);
    }
};

// src: bf16, dst: bf16
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_bf16_16x16x16_bf16_w32;

template <>
struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16>
{
    template <class FloatC>
    __device__ static void
    Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c, const bool opsel)
    {
        // opsel usage
        // false: D0.[0:15]  = result
        // true : D0.[16:31] = result
        reg_c.template AsType<bhalf16_t>()(Number<0>{}) =
            __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32(
                reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[Number<0>{}], opsel);
    }
};

// src: iu8, dst: i32
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_i32_16x16x16_iu8_w32;

template <>
struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const bool neg_a,
                               const int8x16_t& reg_a,
                               const bool neg_b,
                               const int8x16_t& reg_b,
                               FloatC& reg_c,
                               const bool clamp)
    {
        reg_c.template AsType<int32x8_t>()(Number<0>{}) =
            __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
                neg_a,
                bit_cast<int32x4_t>(reg_a),
                neg_b,
                bit_cast<int32x4_t>(reg_b),
                reg_c.template AsType<int32x8_t>()[Number<0>{}],
                clamp);
    }
};

#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
// src: iu4, dst: i32
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_i32_16x16x16_iu4_w32;

template <>
struct intrin_wmma_i32_16x16x16_iu4_w32<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const bool neg_a,
                               const int4x16_t& reg_a,
                               const bool neg_b,
                               const int4x16_t& reg_b,
                               FloatC& reg_c,
                               const bool clamp)
    {
        reg_c.template AsType<int32x8_t>()(Number<0>{}) =
            __builtin_amdgcn_wmma_i32_16x16x16_iu4_w32(
                neg_a,
                bit_cast<int32x4_t>(reg_a),
                neg_b,
                bit_cast<int32x4_t>(reg_b),
                reg_c.template AsType<int32x8_t>()[Number<0>{}],
                clamp);
    }
};
#endif

} // namespace ck
#endif
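The fp16 and bf16 accumulator variants take an extra opsel flag; per the comments above, it selects whether each result lands in the low or high 16 bits of the destination dwords. A hedged sketch (hypothetical helper, not part of this commit) of how a caller could use the flag to pack two result tiles into one register set:

#include "ck/ck.hpp"
#include "ck/utility/amd_wmma.hpp"

// Assumption: AccBuf is a StaticBufferTupleOfVector holding one half16_t
// vector, analogous to the float8_t accumulator in the unit test below.
template <typename AccBuf>
__device__ void pack_two_tiles(const ck::half16_t& a0, const ck::half16_t& b0,
                               const ck::half16_t& a1, const ck::half16_t& b1,
                               AccBuf& acc)
{
    // first tile -> D0.[0:15] of each destination dword (opsel = false)
    ck::intrin_wmma_f16_16x16x16_f16_w32<16, 16>::Run(
        a0, b0, acc.GetVectorTypeReference(ck::Number<0>{}), false);
    // second tile -> D0.[16:31] of the same dwords (opsel = true)
    ck::intrin_wmma_f16_16x16x16_f16_w32<16, 16>::Run(
        a1, b1, acc.GetVectorTypeReference(ck::Number<0>{}), true);
}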
include/ck/utility/data_type.hpp

@@ -942,6 +942,11 @@ using int8x16_t = typename vector_type<int8_t, 16>::type;
 using int8x32_t = typename vector_type<int8_t, 32>::type;
 using int8x64_t = typename vector_type<int8_t, 64>::type;
+#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
+// i4
+using int4x16_t = typename vector_type<int4_t, 16>::type;
+#endif
+
 // Convert X to Y
 template <typename Y, typename X>
 __host__ __device__ constexpr Y type_convert(X x)
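Nothing in the diff turns the experimental flag on by itself. The shape of the alias is constrained by the iu4 wrapper in amd_wmma.hpp, which bit_casts it to int32x4_t, and a bit_cast is only well-formed between equally sized types. A hypothetical compile-time check (not in the commit) makes that explicit:

// Assumption: CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 is an opt-in flag
// defined before including CK headers.
#define CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#include "ck/utility/data_type.hpp"

// bit_cast<int32x4_t>(int4x16_t) in the iu4 wrapper requires matching sizes.
static_assert(sizeof(ck::int4x16_t) == sizeof(ck::int32x4_t),
              "packed i4x16 operand must span the same 4 dwords the WMMA builtin expects");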
test/CMakeLists.txt

@@ -52,3 +52,4 @@ add_subdirectory(block_to_ctile_map)
 add_subdirectory(softmax)
 add_subdirectory(normalization)
 add_subdirectory(data_type)
+add_subdirectory(wmma_op)
test/wmma_op/CMakeLists.txt  0 → 100644

add_test_executable(test_wmma_op wmma_op.cpp)
target_link_libraries(test_wmma_op PRIVATE utility)
test/wmma_op/wmma_op.cpp  0 → 100644

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/utility/amd_wmma.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"

namespace ck {

__global__ void matmul(const half_t* a, const half_t* b, float* c)
{
    const int lIdx = threadIdx.x;

    // a and b fragments are stored in 8 VGPRs each, in packed format, so 16
    // elements each for a and b; a_frag will store one column of the 16x16
    // matrix tile, b_frag will store one row of the 16x16 matrix tile
    half16_t a_frag = {};
    half16_t b_frag = {};

    // initialize c fragment to 0
    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, float, 1, 8, true> c_thread_buf_;

    // lane is (0-31) mod 16 instead of 0-31 due to matrix replication in gfx11
    // see https://atlvsp3.amd.com/sp3_gfx11_5_instructions.pdf page 482
    // TODO: remove this dependency in gfx12 https://ontrack-internal.amd.com/browse/DEGFXSP3-101
    const int lane = lIdx % 16;

    for(int ele = 0; ele < 16; ++ele)
    {
        b_frag[ele] = b[16 * lane + ele];
    }

    // follow origin design
    for(int ele = 0; ele < 16; ++ele)
    {
        a_frag[ele] = a[16 * lane + ele];
    }

    // sync threads, similar to mma_sync
    __syncthreads();

    intrin_wmma_f32_16x16x16_f16_w32<16, 16>::Run(
        a_frag, b_frag, c_thread_buf_.GetVectorTypeReference(Number<0>{}));

    __syncthreads();
    // wait for results, similar to mma_sync

    static_for<0, 8, 1>{}([&](auto ele) {
        const int r = ele * 2 + (lIdx / 16);
        // store results from unpacked c_thread_buf_ output
        c[16 * r + lane] = c_thread_buf_[Number<ele>{}];
    });
}

__global__ void matmul_swizzle_a(const half_t* a, const half_t* b, float* c)
{
    const int lIdx = threadIdx.x;

    half16_t a_frag = {};
    half16_t b_frag = {};

    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, float, 1, 8, true> c_thread_buf_;

    const int lane = lIdx % 16;

    for(int ele = 0; ele < 16; ++ele)
    {
        b_frag[ele] = b[16 * lane + ele];
    }

    const int offset_m = (((lane & 1) << 3) | (lane >> 1));

    for(int ele = 0; ele < 16; ++ele)
    {
        a_frag[ele] = a[16 * offset_m + ele];
    }

    __syncthreads();

    intrin_wmma_f32_16x16x16_f16_w32<16, 16>::Run(
        a_frag, b_frag, c_thread_buf_.GetVectorTypeReference(Number<0>{}));

    __syncthreads();

    static_for<0, 8, 1>{}([&](auto ele) {
        const int blk = lIdx / 16;
        const int r   = ele;
        c[16 * 8 * blk + 16 * r + lane] = c_thread_buf_[Number<ele>{}];
    });
}

} // namespace ck

int main(int, char*[])
{
    std::vector<float> host_a(16 * 16);
    std::vector<float> host_b(16 * 16);
    std::vector<float> host_c(16 * 16);
    std::vector<float> wmma_c(16 * 16);
    std::vector<float> wmma_c_swizzle_a(16 * 16);
    uint64_t num_element = 256;

    // generate matrix a
    for(int i_m = 0; i_m < 16; i_m++)
    {
        for(int i_k = 0; i_k < 16; i_k++)
        {
            host_a[i_m * 16 + i_k] = float(i_m + 1) / 99.0 + (float(i_k + 1) / 100);
            // host_a[i_m * 16 + i_k] = float(i_k);
        }
    }

    // generate matrix b
    for(int i_n = 0; i_n < 16; i_n++)
    {
        for(int i_k = 0; i_k < 16; i_k++)
        {
            host_b[i_n * 16 + i_k] = float(i_n + 1) / 98.0 + (float(i_k + 1) / 100);
            // host_b[i_n * 16 + i_k] = 1.0;
        }
    }

    // run mk_nk_mn gemm on cpu
    for(int i_m = 0; i_m < 16; i_m++)
    {
        for(int i_n = 0; i_n < 16; i_n++)
        {
            for(int i_k = 0; i_k < 16; i_k++)
            {
                host_c[i_m * 16 + i_n] += host_a[i_m * 16 + i_k] * host_b[i_n * 16 + i_k];
            }
        }
    }

    DeviceMem device_a(sizeof(ck::half_t) * num_element);
    DeviceMem device_b(sizeof(ck::half_t) * num_element);
    DeviceMem device_c(sizeof(float) * num_element);

    std::vector<ck::half_t> fp16_a(16 * 16);
    std::vector<ck::half_t> fp16_b(16 * 16);

    // convert fp32 a and b into fp16 on host
    for(int i = 0; i < 16 * 16; i++)
    {
        fp16_a[i] = __float2half_rn(host_a[i]);
        fp16_b[i] = __float2half_rn(host_b[i]);
    }

    device_a.ToDevice(fp16_a.data());
    device_b.ToDevice(fp16_b.data());

    // run single wave wmma on GPU
    ck::matmul<<<1, 32>>>(static_cast<const ck::half_t*>(device_a.GetDeviceBuffer()),
                          static_cast<const ck::half_t*>(device_b.GetDeviceBuffer()),
                          static_cast<float*>(device_c.GetDeviceBuffer()));

    device_c.FromDevice(wmma_c.data());

    bool res = ck::utils::check_err(wmma_c, host_c, "Error: Incorrect results!", 1e-2);

    // run single wave wmma_swizzle_a on GPU
    ck::matmul_swizzle_a<<<1, 32>>>(static_cast<const ck::half_t*>(device_a.GetDeviceBuffer()),
                                    static_cast<const ck::half_t*>(device_b.GetDeviceBuffer()),
                                    static_cast<float*>(device_c.GetDeviceBuffer()));

    device_c.FromDevice(wmma_c_swizzle_a.data());

    bool res_swizzle_a =
        ck::utils::check_err(wmma_c_swizzle_a, host_c, "Error: Incorrect results!", 1e-2);

    if(res && res_swizzle_a)
    {
        std::cout << "test single wave wmma: Pass" << std::endl;
        return 0;
    }
    else
    {
        std::cout << "test single wave wmma: Fail" << std::endl;
        return -1;
    }
}
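The index arithmetic in the two kernels is easier to see with concrete lanes. A standalone host-only sketch (illustration only, not part of the commit) that tabulates the output-row mapping from matmul and the A-row swizzle from matmul_swizzle_a:

#include <cstdio>

int main()
{
    for(int lIdx = 0; lIdx < 32; ++lIdx)
    {
        // gfx11 replicates the 16x16 tile across the two half-waves
        const int lane = lIdx % 16;
        // matmul: thread element `ele` is stored to row 2*ele + (lIdx/16);
        // for ele = 0, lanes 0-15 write row 0 and lanes 16-31 write row 1
        const int row_of_ele0 = 0 * 2 + (lIdx / 16);
        // matmul_swizzle_a: the A row each lane loads;
        // lanes 0,1,2,3,... map to rows 0,8,1,9,...
        const int offset_m = ((lane & 1) << 3) | (lane >> 1);
        std::printf("lIdx=%2d lane=%2d row(ele=0)=%d offset_m=%2d\n",
                    lIdx, lane, row_of_ele0, offset_m);
    }
    return 0;
}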