gaoqiong / composable_kernel · commit 790e21ec

Authored Oct 28, 2022 by aska-0096
Parent: 049cc8af

    Refactor + Add all type unit test(int4 compile failed)

Showing 3 changed files with 428 additions and 195 deletions:

    include/ck/utility/amd_wmma.hpp    +19   -31
    test/wmma_op/wmma_op.cpp           +52   -164
    test/wmma_op/wmma_op_util.hpp      +357  -0
include/ck/utility/amd_wmma.hpp
...
@@ -41,58 +41,51 @@ struct intrin_wmma_f32_16x16x16_bf16_w32<16, 16>
 };
 
 // src: fp16, dst: fp16
-template <index_t MPerWave, index_t NPerWave>
+template <index_t MPerWave, index_t NPerWave, index_t Opsel>
 struct intrin_wmma_f16_16x16x16_f16_w32;
 
-template <>
-struct intrin_wmma_f16_16x16x16_f16_w32<16, 16>
+template <index_t Opsel>
+struct intrin_wmma_f16_16x16x16_f16_w32<16, 16, Opsel>
 {
     template <class FloatC>
-    __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c, const bool opsel)
+    __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
     {
         // opsel usage
         // false: D0.[0:15] = result
         // true : D0.[16:31]= result
         reg_c.template AsType<half16_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(
-            reg_a, reg_b, reg_c.template AsType<half16_t>()[Number<0>{}], opsel);
+            reg_a, reg_b, reg_c.template AsType<half16_t>()[Number<0>{}], Opsel);
     }
 };
-// src: bf16, dst: bf32
+// src: bf16, dst: bf16
-template <index_t MPerWave, index_t NPerWave>
+template <index_t MPerWave, index_t NPerWave, index_t Opsel>
 struct intrin_wmma_bf16_16x16x16_bf16_w32;
 
-template <>
-struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16>
+template <index_t Opsel>
+struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel>
 {
     template <class FloatC>
-    __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c, const bool opsel)
+    __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
     {
         // opsel usage
         // false: D0.[0:15] = result
         // true : D0.[16:31]= result
         reg_c.template AsType<bhalf16_t>()(Number<0>{}) =
             __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32(
-                reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[Number<0>{}], opsel);
+                reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[Number<0>{}], Opsel);
     }
 };
 // src: iu8, dst: i32
-template <index_t MPerWave, index_t NPerWave>
+template <index_t MPerWave, index_t NPerWave, bool neg_a, bool neg_b, bool clamp>
 struct intrin_wmma_i32_16x16x16_iu8_w32;
 
-template <>
-struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16>
+template <bool neg_a, bool neg_b, bool clamp>
+struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp>
 {
     template <class FloatC>
-    __device__ static void Run(const bool neg_a, const int8x16_t& reg_a, const bool neg_b, const int8x16_t& reg_b, FloatC& reg_c, const bool clamp)
+    __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
     {
         reg_c.template AsType<int32x8_t>()(Number<0>{}) =
             __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
...
@@ -107,19 +100,14 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16>
 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
 // src: iu4, dst: i32
-template <index_t MPerWave, index_t NPerWave>
+template <index_t MPerWave, index_t NPerWave, bool neg_a, bool neg_b, bool clamp>
 struct intrin_wmma_i32_16x16x16_iu4_w32;
 
-template <>
-struct intrin_wmma_i32_16x16x16_iu4_w32<16, 16>
+template <bool neg_a, bool neg_b, bool clamp>
+struct intrin_wmma_i32_16x16x16_iu4_w32<16, 16, neg_a, neg_b, clamp>
 {
     template <class FloatC>
-    __device__ static void Run(const bool neg_a, const int4x16_t& reg_a, const bool neg_b, const int4x16_t& reg_b, FloatC& reg_c, const bool clamp)
+    __device__ static void Run(const int4x16_t& reg_a, const int4x16_t& reg_b, FloatC& reg_c)
    {
         reg_c.template AsType<int32x8_t>()(Number<0>{}) =
             __builtin_amdgcn_wmma_i32_16x16x16_iu4_w32(
...
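Net effect of this hunk: the flags that used to be runtime arguments to Run() (opsel for the fp16/bf16 wrappers, neg_a/neg_b/clamp for the integer wrappers) become template parameters of the wrapper struct, so call sites pick them at compile time. A minimal sketch of the call-site change, using the instantiations that the new wmma_op_util.hpp below actually uses (the fragment variable names here are illustrative only, not from the commit):

    // before this commit: flags passed at runtime
    // intrin_wmma_f16_16x16x16_f16_w32<16, 16>::Run(a_frag, b_frag, c_buf, /*opsel=*/false);
    // intrin_wmma_i32_16x16x16_iu8_w32<16, 16>::Run(true, ia_frag, true, ib_frag, ic_buf, false);

    // after this commit: flags baked into the specialization
    intrin_wmma_f16_16x16x16_f16_w32<16, 16, 0>::Run(a_frag, b_frag, c_buf);
    intrin_wmma_i32_16x16x16_iu8_w32<16, 16, true, true, false>::Run(ia_frag, ib_frag, ic_buf);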
test/wmma_op/wmma_op.cpp
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
 
+#include <algorithm>
+#include <cstdlib>
 #include <iostream>
 #include <numeric>
-#include <initializer_list>
-#include <cstdlib>
+#include <tuple>
+#include <vector>
 
 #include "ck/ck.hpp"
-#include "ck/utility/amd_wmma.hpp"
-#include "ck/library/utility/check_err.hpp"
-#include "ck/library/utility/device_memory.hpp"
-#include "ck/library/utility/host_tensor.hpp"
-#include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "test/wmma_op/wmma_op_util.hpp"
-namespace ck {
-
-__global__ void matmul(const half_t* a, const half_t* b, float* c)
-{
-    const int lIdx = threadIdx.x;
-
-    // a and b fragments are stored in 8 VGPRs each, in packed format, so 16 elements each for a and
-    // b a_frag will store one column of the 16x16 matrix tile b_frag will store one row of the
-    // 16x16 matrix tile
-    half16_t a_frag = {};
-    half16_t b_frag = {};
-    // initialize c fragment to 0
-    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, float, 1, 8, true> c_thread_buf_;
-
-    // lane is (0-31) mod 16 instead of 0-31 due to matrix replication in gfx11
-    // see https://atlvsp3.amd.com/sp3_gfx11_5_instructions.pdf page 482
-    // TODO: remove this dependency in gfx12 https://ontrack-internal.amd.com/browse/DEGFXSP3-101
-    const int lane = lIdx % 16;
-
-    for(int ele = 0; ele < 16; ++ele)
-    {
-        b_frag[ele] = b[16 * lane + ele];
-    }
-
-    // follow origin design
-    for(int ele = 0; ele < 16; ++ele)
-    {
-        a_frag[ele] = a[16 * lane + ele];
-    }
-
-    // sync threads, similar to mma_sync
-    __syncthreads();
-
-    intrin_wmma_f32_16x16x16_f16_w32<16, 16>::Run(
-        a_frag, b_frag, c_thread_buf_.GetVectorTypeReference(Number<0>{}));
-
-    __syncthreads();
-
-    // wait for results, similar to mma_sync
-    static_for<0, 8, 1>{}([&](auto ele) {
-        const int r = ele * 2 + (lIdx / 16);
-        // store results from unpacked c_thread_buf_ output
-        c[16 * r + lane] = c_thread_buf_[Number<ele>{}];
-    });
-}
-
-__global__ void matmul_swizzle_a(const half_t* a, const half_t* b, float* c)
-{
-    const int lIdx = threadIdx.x;
-
-    half16_t a_frag = {};
-    half16_t b_frag = {};
-    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, float, 1, 8, true> c_thread_buf_;
-
-    const int lane = lIdx % 16;
-
-    for(int ele = 0; ele < 16; ++ele)
-    {
-        b_frag[ele] = b[16 * lane + ele];
-    }
-
-    const int offset_m = (((lane & 1) << 3) | (lane >> 1));
-    for(int ele = 0; ele < 16; ++ele)
-    {
-        a_frag[ele] = a[16 * offset_m + ele];
-    }
-
-    __syncthreads();
-
-    intrin_wmma_f32_16x16x16_f16_w32<16, 16>::Run(
-        a_frag, b_frag, c_thread_buf_.GetVectorTypeReference(Number<0>{}));
-
-    __syncthreads();
-
-    static_for<0, 8, 1>{}([&](auto ele) {
-        const int blk = lIdx / 16;
-        const int r   = ele;
-        c[16 * 8 * blk + 16 * r + lane] = c_thread_buf_[Number<ele>{}];
-    });
-}
-
-} // namespace ck
+template <typename SrcType,
+          typename DstType,
+          typename GPUAccType,
+          typename CPUAccType,
+          ck::index_t AccNum>
+bool run_test()
+{
+    using Row         = ck::tensor_layout::gemm::RowMajor;
+    using Col         = ck::tensor_layout::gemm::ColumnMajor;
+    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+    bool pass = true;
+
+    const auto matmul_default =
+        ck::wmma_op_util::matmul<SrcType, DstType, GPUAccType, AccNum>;
+    const auto matmul_swizzle_a =
+        ck::wmma_op_util::matmul_swizzle_a<SrcType, DstType, GPUAccType, AccNum>;
+
+    const auto wmma_kernel_container = std::make_tuple(matmul_default, matmul_swizzle_a);
+
+    ck::static_for<0, 2, 1>{}([&](auto i) {
+        pass &= ck::wmma_op_util::TestWmma<
+            decltype(std::get<ck::Number<i>{}>(wmma_kernel_container)),
+            SrcType,
+            SrcType,
+            DstType,
+            GPUAccType,
+            CPUAccType,
+            decltype(Row{}),
+            decltype(Col{}),
+            decltype(Row{}),
+            PassThrough,
+            PassThrough,
+            PassThrough,
+            AccNum>{}(std::get<ck::Number<i>{}>(wmma_kernel_container));
+    });
+
+    return pass ? 1 : 0;
+}
 int main(int, char*[])
 {
-    std::vector<float> host_a(16 * 16);
-    std::vector<float> host_b(16 * 16);
-    std::vector<float> host_c(16 * 16);
-    std::vector<float> wmma_c(16 * 16);
-    std::vector<float> wmma_c_swizzle_a(16 * 16);
-
-    uint64_t num_element = 256;
-
-    // generate matrix a
-    for(int i_m = 0; i_m < 16; i_m++)
-    {
-        for(int i_k = 0; i_k < 16; i_k++)
-        {
-            host_a[i_m * 16 + i_k] = float(i_m + 1) / 99.0 + (float(i_k + 1) / 100);
-            // host_a[i_m * 16 + i_k] = float(i_k);
-        }
-    }
-
-    // generate matrix b
-    for(int i_n = 0; i_n < 16; i_n++)
-    {
-        for(int i_k = 0; i_k < 16; i_k++)
-        {
-            host_b[i_n * 16 + i_k] = float(i_n + 1) / 98.0 + (float(i_k + 1) / 100);
-            // host_b[i_n * 16 + i_k] = 1.0;
-        }
-    }
-
-    // run mk_nk_mn gemm on cpu
-    for(int i_m = 0; i_m < 16; i_m++)
-    {
-        for(int i_n = 0; i_n < 16; i_n++)
-        {
-            for(int i_k = 0; i_k < 16; i_k++)
-            {
-                host_c[i_m * 16 + i_n] += host_a[i_m * 16 + i_k] * host_b[i_n * 16 + i_k];
-            }
-        }
-    }
-
-    DeviceMem device_a(sizeof(ck::half_t) * num_element);
-    DeviceMem device_b(sizeof(ck::half_t) * num_element);
-    DeviceMem device_c(sizeof(float) * num_element);
-
-    std::vector<ck::half_t> fp16_a(16 * 16);
-    std::vector<ck::half_t> fp16_b(16 * 16);
-
-    // convert fp32 a and b into fp16 on host
-    for(int i = 0; i < 16 * 16; i++)
-    {
-        fp16_a[i] = __float2half_rn(host_a[i]);
-        fp16_b[i] = __float2half_rn(host_b[i]);
-    }
-
-    device_a.ToDevice(fp16_a.data());
-    device_b.ToDevice(fp16_b.data());
-
-    // run single wave wmma on GPU
-    ck::matmul<<<1, 32>>>(static_cast<const ck::half_t*>(device_a.GetDeviceBuffer()),
-                          static_cast<const ck::half_t*>(device_b.GetDeviceBuffer()),
-                          static_cast<float*>(device_c.GetDeviceBuffer()));
-    device_c.FromDevice(wmma_c.data());
-
-    // run single wave wmma_swizzle_a on GPU
-    ck::matmul_swizzle_a<<<1, 32>>>(static_cast<const ck::half_t*>(device_a.GetDeviceBuffer()),
-                                    static_cast<const ck::half_t*>(device_b.GetDeviceBuffer()),
-                                    static_cast<float*>(device_c.GetDeviceBuffer()));
-    device_c.FromDevice(wmma_c_swizzle_a.data());
-
-    // result check
-    bool res           = true;
-    bool res_swizzle_a = true;
-    res = ck::utils::check_err(wmma_c, host_c, "Error: Incorrect results!", 1e-2);
-    res_swizzle_a =
-        ck::utils::check_err(wmma_c_swizzle_a, host_c, "Error: Incorrect results!", 1e-2);
-
-    if(res && res_swizzle_a)
-    {
-        std::cout << "test single wave wmma: Pass" << std::endl;
-        return 0;
-    }
-    else
-    {
-        std::cout << "test single wave wmma: Fail" << std::endl;
-        return -1;
-    }
+    bool pass = true;
+
+    // clang-format off
+    //               |SrcType      |DstType      |GPUAccType   |CPUAccType  |AccNum
+    pass &= run_test<  ck::half_t,   float,        float,        float,       8>();
+    pass &= run_test<  ck::half_t,   ck::half_t,   ck::half_t,   ck::half_t,  16>();
+    pass &= run_test<  ck::bhalf_t,  ck::bhalf_t,  ck::bhalf_t,  float,       16>();
+    pass &= run_test<  int8_t,       int8_t,       int32_t,      int32_t,     8>();
+    // clang-format on
+
+    std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl;
+
+    return pass ? 0 : 1;
 }
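The commit message notes that the int4 case does not compile yet, which is presumably why it is missing from the table above. If it is enabled later, a row for it would likely mirror the int8 entry; the following is only an illustrative sketch, and ck::int4_t as the element type is an assumption (it is only available when CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 is defined):

    #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
    // hypothetical int4 test row, not part of this commit
    pass &= run_test<ck::int4_t, ck::int4_t, int32_t, int32_t, 8>();
    #endif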
test/wmma_op/wmma_op_util.hpp (new file, 0 → 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/utility/amd_wmma.hpp"

namespace ck {
namespace wmma_op_util {

template <typename src_vec, typename acc_vec>
__device__ void builtin_wmma_naive_selector(const src_vec&, const src_vec&, acc_vec&)
{
}
template <>
__device__ void builtin_wmma_naive_selector<
    half16_t,
    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, float, 1, 8, true>>(
    const half16_t& reg_a,
    const half16_t& reg_b,
    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, float, 1, 8, true>& reg_c)
{
    intrin_wmma_f32_16x16x16_f16_w32<16, 16>::Run(
        reg_a, reg_b, reg_c.GetVectorTypeReference(Number<0>{}));
}

template <>
__device__ void builtin_wmma_naive_selector<
    half16_t,
    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, half_t, 1, 16, true>>(
    const half16_t& reg_a,
    const half16_t& reg_b,
    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, half_t, 1, 16, true>& reg_c)
{
    intrin_wmma_f16_16x16x16_f16_w32<16, 16, 0>::Run(
        reg_a, reg_b, reg_c.GetVectorTypeReference(Number<0>{}));
}

template <>
__device__ void builtin_wmma_naive_selector<
    bhalf16_t,
    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, bhalf_t, 1, 16, true>>(
    const bhalf16_t& reg_a,
    const bhalf16_t& reg_b,
    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, bhalf_t, 1, 16, true>& reg_c)
{
    intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, 0>::Run(
        reg_a, reg_b, reg_c.GetVectorTypeReference(Number<0>{}));
}

template <>
__device__ void builtin_wmma_naive_selector<
    int8x16_t,
    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, int32_t, 1, 8, true>>(
    const int8x16_t& reg_a,
    const int8x16_t& reg_b,
    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, int32_t, 1, 8, true>& reg_c)
{
    intrin_wmma_i32_16x16x16_iu8_w32<16, 16, true, true, false>::Run(
        reg_a, reg_b, reg_c.GetVectorTypeReference(Number<0>{}));
}
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
__device__ void builtin_wmma_naive_selector<
    int4x16_t,
    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, int32_t, 1, 8, true>>(
    const int4x16_t& reg_a,
    const int4x16_t& reg_b,
    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, int32_t, 1, 8, true>& reg_c)
{
    intrin_wmma_i32_16x16x16_iu4_w32<16, 16, true, true, false>::Run(
        reg_a, reg_b, reg_c.GetVectorTypeReference(Number<0>{}));
}
#endif
template <typename src_t, typename dst_t, typename acc_t, index_t acc_num>
__global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
{
    const int lIdx = threadIdx.x;

    // a and b fragments are stored in 8 VGPRs each, in packed format, so 16 elements each for a and
    // b a_frag will store one column of the 16x16 matrix tile b_frag will store one row of the
    // 16x16 matrix tile
    using src_vec  = typename vector_type<src_t, 16>::type;
    src_vec a_frag = {};
    src_vec b_frag = {};

    // initialize c fragment to 0
    using acc_vec = StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, acc_t, 1, acc_num, true>;
    acc_vec c_thread_buf_;

    // lane is (0-31) mod 16 instead of 0-31 due to matrix replication in gfx11
    // see https://atlvsp3.amd.com/sp3_gfx11_5_instructions.pdf page 482
    // TODO: remove this dependency in gfx12 https://ontrack-internal.amd.com/browse/DEGFXSP3-101
    const int lane = lIdx % 16;

    for(int ele = 0; ele < 16; ++ele)
    {
        b_frag[ele] = b[16 * lane + ele];
    }

    // follow origin design
    for(int ele = 0; ele < 16; ++ele)
    {
        a_frag[ele] = a[16 * lane + ele];
    }

    // sync threads, similar to mma_sync
    __syncthreads();

    builtin_wmma_naive_selector<src_vec, acc_vec>(a_frag, b_frag, c_thread_buf_);

    __syncthreads();

    // wait for results, similar to mma_sync
    static_for<0, 8, 1>{}([&](auto ele) {
        const int r = ele * 2 + (lIdx / 16);
        // store results from unpacked c_thread_buf_ output
        c[16 * r + lane] = ck::type_convert<dst_t>(c_thread_buf_[Number<ele * acc_num / 8>{}]);
    });
}
template <typename src_t, typename dst_t, typename acc_t, index_t acc_num>
__global__ void matmul_swizzle_a(const src_t* a, const src_t* b, dst_t* c)
{
    const int lIdx = threadIdx.x;

    using src_vec  = typename vector_type<src_t, 16>::type;
    src_vec a_frag = {};
    src_vec b_frag = {};

    using acc_vec = StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, acc_t, 1, acc_num, true>;
    acc_vec c_thread_buf_;

    const int lane = lIdx % 16;

    for(int ele = 0; ele < 16; ++ele)
    {
        b_frag[ele] = b[16 * lane + ele];
    }

    const int offset_m = (((lane & 1) << 3) | (lane >> 1));
    for(int ele = 0; ele < 16; ++ele)
    {
        a_frag[ele] = a[16 * offset_m + ele];
    }

    __syncthreads();

    builtin_wmma_naive_selector<src_vec, acc_vec>(a_frag, b_frag, c_thread_buf_);

    __syncthreads();

    static_for<0, 8, 1>{}([&](auto ele) {
        const int blk = lIdx / 16;
        const int r   = ele;
        c[16 * 8 * blk + 16 * r + lane] =
            ck::type_convert<dst_t>(c_thread_buf_[Number<ele * acc_num / 8>{}]);
    });
}
struct GemmParams
{
    GemmParams()
        : M(16), N(16), K(16), StrideA(16), StrideB(16), StrideC(16), alpha(1), beta(0)
    {
    }

    ck::index_t M;
    ck::index_t N;
    ck::index_t K;

    ck::index_t StrideA;
    ck::index_t StrideB;
    ck::index_t StrideC;

    float alpha;
    float beta;
};
template <typename GemmInstance,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
void RunHostGEMM(const Tensor<ADataType>& A,
                 const Tensor<BDataType>& B,
                 Tensor<CDataType>& C,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CElementwiseOperation c_element_op)
{
    auto ref_gemm     = GemmInstance{};
    auto ref_invoker  = ref_gemm.MakeInvoker();
    auto ref_argument = ref_gemm.MakeArgument(A, B, C, a_element_op, b_element_op, c_element_op);

    ref_invoker.Run(ref_argument);
}
template <typename KernelType, typename ADataType, typename BDataType, typename CDataType>
bool RunDeviceGEMM(KernelType kernel,
                   const Tensor<ADataType>& A,
                   const Tensor<BDataType>& B,
                   Tensor<CDataType>& C)
{
    DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpaceSize());
    DeviceMem b_n_k_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpaceSize());
    DeviceMem c_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpaceSize());

    a_m_k_device_buf.ToDevice(A.mData.data());
    b_n_k_device_buf.ToDevice(B.mData.data());

    kernel<<<1, 32>>>(static_cast<const ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
                      static_cast<const BDataType*>(b_n_k_device_buf.GetDeviceBuffer()),
                      static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()));
    c_m_n_device_buf.FromDevice(C.mData.data());

    return true;
}
template <typename DeviceWmma,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename GPUAccDataType,
          typename CPUAccDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          index_t CAccNum>
struct TestWmma
{
    auto PrepareGemmTensor(const ck::wmma_op_util::GemmParams& params)
    {
        auto f_host_tensor_descriptor =
            [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
                if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
                {
                    return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                                std::vector<std::size_t>({stride, 1}));
                }
                else
                {
                    return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                                std::vector<std::size_t>({1, stride}));
                }
            };

        Tensor<ADataType> a_m_k(
            f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
        Tensor<BDataType> b_n_k(
            f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));

        Tensor<CDataType> c_m_n_host_result(
            f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
        Tensor<CDataType> c_m_n_device_result(
            f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));

        auto f_generate_tensor_value = [](auto& tensor, auto type) {
            using dataType = decltype(type);
            tensor.GenerateTensorValue(GeneratorTensor_2<dataType>{-5, 5});
        };

        f_generate_tensor_value(a_m_k, ADataType{});
        f_generate_tensor_value(b_n_k, BDataType{});

        return std::make_tuple(a_m_k, b_n_k, c_m_n_host_result, c_m_n_device_result);
    }
    auto operator()(const DeviceWmma& wmma_kernel)
    {
        std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name
                  << ", CLayout = " << CLayout{}.name << std::endl;

        // Arrange
        ck::wmma_op_util::GemmParams params;
        params.M       = 16;
        params.N       = 16;
        params.K       = 16;
        params.StrideA = 16;
        params.StrideB = 16;
        params.StrideC = 16;

        auto host_tensors = PrepareGemmTensor(params);

        const Tensor<ADataType>& a  = std::get<0>(host_tensors);
        const Tensor<BDataType>& b  = std::get<1>(host_tensors);
        Tensor<CDataType>& c_host   = std::get<2>(host_tensors);
        Tensor<CDataType>& c_device = std::get<3>(host_tensors);

        auto a_element_op = AElementwiseOperation{};
        auto b_element_op = BElementwiseOperation{};
        auto c_element_op = CElementwiseOperation{};

        using ReferenceGemmInstance =
            ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                      BDataType,
                                                      CDataType,
                                                      CPUAccDataType,
                                                      AElementwiseOperation,
                                                      BElementwiseOperation,
                                                      CElementwiseOperation>;
        ck::wmma_op_util::RunHostGEMM<ReferenceGemmInstance>(
            a, b, c_host, a_element_op, b_element_op, c_element_op);

        // Act
        bool is_supported = ck::wmma_op_util::RunDeviceGEMM(wmma_kernel, a, b, c_device);

        if(is_supported)
        {
            // Assert
            bool res = false;
            if(std::is_same<CDataType, float>::value)
            {
                res = ck::utils::check_err(c_device.mData, c_host.mData);
                std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
            }
            else if(std::is_same<CDataType, ck::half_t>::value)
            {
                res = ck::utils::check_err(c_device.mData, c_host.mData);
                std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
            }
            else if(std::is_same<CDataType, ck::bhalf_t>::value)
            {
                // 0.5 Pixel Error Tolerance is introduced by Accumulator difference.
                // BF16 WMMA Accumulator is in BF16 Type while On Host-side Accumulator is Float.
                res = ck::utils::check_err(
                    c_device.mData, c_host.mData, "Error: Incorrect results!", 0, 1.0);
                std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
            }
            else if(std::is_same<CDataType, int8_t>::value)
            {
                res = ck::utils::check_err(c_device.mData, c_host.mData);
                std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
            }
            else if(std::is_same<CDataType, double>::value)
            {
                res = ck::utils::check_err(c_device.mData, c_host.mData);
                std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
            }
            else
            {
                std::cout << "UNSUPPORTED CDataType" << std::endl;
            }

            return res;
        }
        else
        {
            return true;
        }
    }
};

} // namespace wmma_op_util
} // namespace ck
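The two kernels in this header differ only in how a lane picks its row of A and where it scatters the accumulator: matmul loads A row lane = lIdx % 16 and writes output row r = ele * 2 + (lIdx / 16), while matmul_swizzle_a permutes the A row with offset_m = ((lane & 1) << 3) | (lane >> 1) and gives each half-wave (blk = lIdx / 16) its own 8x16 slab of C. A small standalone host-side sketch of that lane-to-row permutation (illustrative only, not part of the commit):

    #include <cstdio>

    int main()
    {
        // Same formula as matmul_swizzle_a: ((lane & 1) << 3) | (lane >> 1)
        for(int lane = 0; lane < 16; ++lane)
        {
            const int offset_m = ((lane & 1) << 3) | (lane >> 1);
            std::printf("lane %2d -> A row %2d\n", lane, offset_m);
        }
        // prints: lane 0 -> 0, lane 1 -> 8, lane 2 -> 1, lane 3 -> 9, ..., lane 15 -> 15
        return 0;
    }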