Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
76bb51f4
Commit
76bb51f4
authored
Mar 09, 2024
by
Jing Zhang
Browse files
merge
parents
92d58a8b
56a67231
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
96 additions
and
9 deletions
+96
-9
example/01_gemm/gemm_xdl_fp16_fp8.cpp
example/01_gemm/gemm_xdl_fp16_fp8.cpp
+8
-2
example/01_gemm/run_gemm_example.inc
example/01_gemm/run_gemm_example.inc
+87
-2
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
...ude/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
+0
-4
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
.../ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
+1
-1
No files found.
example/01_gemm/gemm_xdl_fp16_fp8.cpp
View file @
76bb51f4
...
@@ -33,8 +33,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
...
@@ -33,8 +33,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
LoopSched
,
PipelineVer
,
ComputeType
>
;
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
LoopSched
,
PipelineVer
,
ComputeType
>
;
// clang-format on
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
BDataType
,
CDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CElementOp
,
ComputeType
>
;
#include "run_gemm_example.inc"
#include "run_gemm_example.inc"
...
...
example/01_gemm/run_gemm_example.inc
View file @
76bb51f4
...
@@ -5,6 +5,88 @@
...
@@ -5,6 +5,88 @@
#include "ck/tensor_operation/gpu/device/device_gemm_streamk.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_streamk.hpp"
/// Compile-time lookup of the `rtol` argument handed to ck::utils::check_err
/// when the device GEMM result is compared against the host reference.
///
/// @tparam DataType element type of the tensor being checked.
/// @return the tolerance for DataType; any type without a dedicated entry
///         gets 1e-3.
///
/// The type tests are evaluated in the order listed below and the first match
/// wins.
template <typename DataType>
inline __host__ __device__ constexpr double get_rtol()
{
    constexpr bool is_f32  = std::is_same_v<DataType, float>;
    constexpr bool is_f64  = std::is_same_v<DataType, double>;
    constexpr bool is_f16  = std::is_same_v<DataType, ck::half_t>;
    constexpr bool is_bf16 = std::is_same_v<DataType, ck::bhalf_t>;
    constexpr bool is_i32  = std::is_same_v<DataType, int32_t>;
    constexpr bool is_i8   = std::is_same_v<DataType, int8_t>;
    constexpr bool is_f8   = std::is_same_v<DataType, ck::f8_t>;
    constexpr bool is_bf8  = std::is_same_v<DataType, ck::bf8_t>;

    return is_f32    ? 1e-3
           : is_f64  ? 1e-6
           : is_f16  ? 1e-3
           : is_bf16 ? 5e-2
           : is_i32  ? 1e-1
           : is_i8   ? 1e-1
           : is_f8   ? 1e-1   // 240 and 224 are acceptable
           : is_bf8  ? 1.5e-1 // 57344 and 49152 are acceptable
                     : 1e-3;  // fallback for every other type
}
/// Compile-time lookup of the `atol` argument handed to ck::utils::check_err
/// when the device GEMM result is compared against the host reference.
///
/// @tparam DataType element type of the tensor being checked.
/// @return the tolerance for DataType; any type without a dedicated entry
///         gets 1e-3.
///
/// The type tests are evaluated in the order listed below and the first match
/// wins.
template <typename DataType>
inline __host__ __device__ constexpr double get_atol()
{
    constexpr bool is_f32  = std::is_same_v<DataType, float>;
    constexpr bool is_f64  = std::is_same_v<DataType, double>;
    constexpr bool is_f16  = std::is_same_v<DataType, ck::half_t>;
    constexpr bool is_bf16 = std::is_same_v<DataType, ck::bhalf_t>;
    constexpr bool is_i32  = std::is_same_v<DataType, int32_t>;
    constexpr bool is_i8   = std::is_same_v<DataType, int8_t>;
    constexpr bool is_f8   = std::is_same_v<DataType, ck::f8_t>;
    constexpr bool is_bf8  = std::is_same_v<DataType, ck::bf8_t>;

    return is_f32    ? 1e-3
           : is_f64  ? 1e-6
           : is_f16  ? 1e-3
           : is_bf16 ? 5e-2
           : is_i32  ? 1e-1
           : is_i8   ? 1e-1
           : is_f8   ? 16.1   // 240 and 224 are acceptable (they differ by 16)
           : is_bf8  ? 8192.1 // 57344 and 49152 are acceptable (they differ by 8192)
                     : 1e-3;  // fallback for every other type
}
template
<
typename
ProblemType
>
template
<
typename
ProblemType
>
bool
run_gemm
(
const
ProblemType
&
problem_size
,
const
ExecutionConfig
&
config
)
bool
run_gemm
(
const
ProblemType
&
problem_size
,
const
ExecutionConfig
&
config
)
{
{
...
@@ -256,8 +338,11 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -256,8 +338,11 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
#else
#else
c_m_n_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
c_m_n_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
return
ck
::
utils
::
check_err
(
return
ck
::
utils
::
check_err
(
c_m_n_device_result
,
c_m_n_device_result
,
c_m_n_host_result
,
"Error: Incorrect results!"
,
1
e
-
1
,
1
e
-
1
);
c_m_n_host_result
,
"Error: Incorrect results!"
,
get_rtol
<
CDataType
>
(),
get_atol
<
CDataType
>
());
#endif
#endif
}
}
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
View file @
76bb51f4
...
@@ -607,7 +607,6 @@ struct BlockwiseGemmWMMA
...
@@ -607,7 +607,6 @@ struct BlockwiseGemmWMMA
A_K1
>
;
A_K1
>
;
};
};
#if 0
template
<
>
template
<
>
struct
AThreadCopySelector
<
false
>
struct
AThreadCopySelector
<
false
>
{
{
...
@@ -622,7 +621,6 @@ struct BlockwiseGemmWMMA
...
@@ -622,7 +621,6 @@ struct BlockwiseGemmWMMA
5
,
5
,
A_K1
>
;
A_K1
>
;
};
};
#endif
template
<
bool
EnableLds
>
template
<
bool
EnableLds
>
struct
BThreadCopySelector
;
struct
BThreadCopySelector
;
...
@@ -646,7 +644,6 @@ struct BlockwiseGemmWMMA
...
@@ -646,7 +644,6 @@ struct BlockwiseGemmWMMA
B_K1
>
;
B_K1
>
;
};
};
#if 0
template
<
>
template
<
>
struct
BThreadCopySelector
<
false
>
struct
BThreadCopySelector
<
false
>
{
{
...
@@ -661,7 +658,6 @@ struct BlockwiseGemmWMMA
...
@@ -661,7 +658,6 @@ struct BlockwiseGemmWMMA
5
,
5
,
B_K1
>
;
B_K1
>
;
};
};
#endif
typename
AThreadCopySelector
<
AEnableLds
>::
type
a_thread_copy_
;
typename
AThreadCopySelector
<
AEnableLds
>::
type
a_thread_copy_
;
typename
BThreadCopySelector
<
BEnableLds
>::
type
b_thread_copy_
;
typename
BThreadCopySelector
<
BEnableLds
>::
type
b_thread_copy_
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
View file @
76bb51f4
...
@@ -98,7 +98,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
...
@@ -98,7 +98,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
static
constexpr
auto
BEnableLds_manu
=
false
;
static
constexpr
auto
BEnableLds_manu
=
false
;
static
constexpr
auto
AEnableLds
=
static
constexpr
auto
AEnableLds
=
tru
e
;
// AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
fals
e
;
// AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
static
constexpr
auto
BEnableLds
=
static
constexpr
auto
BEnableLds
=
true
;
// BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
true
;
// BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment