gaoqiong / composable_kernel · Commits · aaf3d81d

Commit aaf3d81d, authored Dec 16, 2020 by Jing Zhang

    clean code

Parent: 1d6022b1

Changes: 6 changed files with 134 additions and 161 deletions (+134 -161)
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp                    +0    -28
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_fp16_bfp16.hpp          +3    -2
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy_v2.hpp  +6    -12
composable_kernel/include/tensor_operation/xdlops_gemm.hpp                              +106  -113
composable_kernel/include/utility/amd_xdlops.hpp                                        +4    -4
composable_kernel/include/utility/float_type.amd.hpp.in                                 +15   -2
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp

@@ -55,34 +55,6 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
     __device__ constexpr auto GetOutputLayout() const { return XdlopsGemm.GetOutputLayout(); }

-#if CK_WORKAROUND_SWDEV_241664
-    template <index_t MRepeats_ = MRepeats, index_t NRepeats_ = NRepeats>
-    __device__ constexpr auto CreateOutputVecZero() const;
-
-    template <>
-    __device__ constexpr auto CreateOutputVecZero<2, 1>() const
-    {
-        return c_vec32_2_2_t::CreateVecZero();
-    }
-
-    template <>
-    __device__ constexpr auto CreateOutputVecZero<1, 2>() const
-    {
-        return c_vec32_2_2_t::CreateVecZero();
-    }
-
-    template <>
-    __device__ constexpr auto CreateOutputVecZero<1, 1>() const
-    {
-        return XdlopsGemm.GetOutputLayout().CreateOutputVecZero();
-    }
-#else
-    __device__ constexpr auto CreateOutputVecZero() const
-    {
-        return XdlopsGemm.GetOutputLayout().CreateOutputVecZero();
-    }
-#endif
-
     __device__ constexpr auto GetNumBlks() const
     {
 #if CK_WORKAROUND_SWDEV_241664
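Note: the block deleted above existed only to work around SWDEV-241664 by hard-coding the zero-initialized accumulator type for each (MRepeats, NRepeats) pair; after this commit the buffer is obtained generically instead (see the gridwise change below). A minimal host-side sketch of the retired pattern, using illustrative stand-in types rather than the real c_vec32_* unions:

#include <array>

struct Vec32   { std::array<float, 32> n{}; }; // stand-in for c_vec32_*
struct Vec32x2 { std::array<float, 64> n{}; }; // stand-in for c_vec32_2_2_t

// Generic case plus per-shape specializations, mirroring the deleted code.
template <int MRepeats, int NRepeats>
constexpr auto CreateOutputVecZero() { return Vec32{}; }

template <>
constexpr auto CreateOutputVecZero<2, 1>() { return Vec32x2{}; }

template <>
constexpr auto CreateOutputVecZero<1, 2>() { return Vec32x2{}; }

int main()
{
    auto acc = CreateOutputVecZero<2, 1>(); // 64 zero-initialized floats
    return static_cast<int>(acc.n[0]);      // 0
}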
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_fp16_bfp16.hpp

@@ -210,7 +210,8 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
         // get zero-initialized output register of vector type
         // auto c_thread_vec = blockwise_gemm.CreateOutputVecZero();
-        auto c_thread_vec = float_vec128_t{};
+        constexpr index_t c_thread_size = MPerBlock * NPerBlock / BlockSize;
+        auto c_thread_vec = GetRegBuffer<AccFloat, c_thread_size>();

         // preload data into LDS
         {

@@ -325,7 +326,7 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
                            m_thread_data_on_global % (M2 * M1) / M2,
                            m_thread_data_on_global % M2,
                            n_thread_data_on_global))
-                .Run(c_thread_vec.At(Number<16>{})[Number<blk_id>{}], p_c_global);
+                .Store(c_thread_vec.At(Number<M0 * M2>{})[Number<blk_id>{}], p_c_global);
         });
     }
 }
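The replacement above sizes the accumulator from the tile parameters instead of hard-coding float_vec128_t. A minimal sketch of the assumed GetRegBuffer semantics (RegBuffer is illustrative; the real specializations live in float_type.amd.hpp.in, extended at the end of this commit):

#include <array>
#include <cstdio>

template <typename T, int N>
struct RegBuffer
{
    std::array<T, N> data{}; // value-initialized: all zeros
};

template <typename T, int N>
constexpr auto GetRegBuffer() { return RegBuffer<T, N>{}; }

int main()
{
    // With MPerBlock = NPerBlock = 128 and BlockSize = 256 (illustrative
    // numbers), each thread owns 128 * 128 / 256 = 64 accumulator values.
    constexpr int c_thread_size = 128 * 128 / 256;
    auto c_thread_vec = GetRegBuffer<float, c_thread_size>();
    std::printf("%zu\n", c_thread_vec.data.size()); // 64
}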
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy_v2.hpp

@@ -84,9 +84,9 @@ struct ThreadwiseGenericTensorSliceCopy_v5
     }

     template <typename DstData, typename SrcData>
-    __device__ static void load_data(DstData& dst, const SrcData* p_src, index_t src_offset)
+    __device__ static auto load_data(const SrcData* p_src, index_t src_offset)
     {
-        dst = *reinterpret_cast<const DstData*>(&p_src[src_offset]);
+        return *reinterpret_cast<const DstData*>(&p_src[src_offset]);
     }

     template <typename DstData, typename SrcData>

@@ -104,9 +104,7 @@ struct ThreadwiseGenericTensorSliceCopy_v5
     template <typename SrcCoord>
     __device__ static auto run(const float* p_src, const SrcCoord src_coord_begin)
     {
-        float r;
-        load_data(r, p_src, src_coord_begin.GetOffset());
-        return r;
+        return load_data<float>(p_src, src_coord_begin.GetOffset());
     }
 };

@@ -116,9 +114,7 @@ struct ThreadwiseGenericTensorSliceCopy_v5
     template <typename SrcCoord>
     __device__ static auto run(const float* p_src, const SrcCoord src_coord_begin)
     {
-        float2_t r;
-        load_data(r, p_src, src_coord_begin.GetOffset());
-        return r;
+        return load_data<float2_t>(p_src, src_coord_begin.GetOffset());
     }
 };

@@ -128,9 +124,7 @@ struct ThreadwiseGenericTensorSliceCopy_v5
     template <typename SrcCoord>
     __device__ static auto run(const float* p_src, const SrcCoord src_coord_begin)
     {
-        float4_t r;
-        load_data(r, p_src, src_coord_begin.GetOffset());
-        return r;
+        return load_data<float4_t>(p_src, src_coord_begin.GetOffset());
     }
 };

@@ -237,7 +231,7 @@ struct ThreadwiseGenericTensorSliceCopy_v5
     }

     template <typename SrcData, typename DstData>
-    __device__ void Run(SrcData src, DstData* p_dst)
+    __device__ void Store(SrcData src, DstData* p_dst)
     {
         constexpr auto vector_access_dim = Number<DstVectorWriteDim>{};
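The load_data refactor above replaces an out-parameter with a returned value, which lets the three run overloads collapse to one-liners. A host-side sketch of the new shape (no __device__, and index_t replaced by int):

#include <cstdio>

template <typename DstData, typename SrcData>
DstData load_data(const SrcData* p_src, int src_offset)
{
    // Reinterpret the source memory at the offset as DstData, as in the diff.
    return *reinterpret_cast<const DstData*>(&p_src[src_offset]);
}

int main()
{
    float buf[4] = {1.0f, 2.0f, 3.0f, 4.0f};
    float r = load_data<float>(buf, 2); // was: float r; load_data(r, buf, 2);
    std::printf("%f\n", r);             // 3.000000
}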
composable_kernel/include/tensor_operation/xdlops_gemm.hpp

@@ -15,13 +15,13 @@ enum struct mfma_instr
     mfma_f32_4x4x1xf32,
     mfma_f32_32x32x2xf32, // k reduction
     mfma_f32_16x16x4xf32, // k reduction
-    // fp16
+    // fp16
     mfma_f32_32x32x4f16,
     mfma_f32_16x16x4f16,
     mfma_f32_4x4x4f16,
     mfma_f32_32x32x8f16,  // k reduction
     mfma_f32_16x16x16f16, // k reduction
-    // bfp16
+    // bfp16
     mfma_f32_32x32x2bf16,
     mfma_f32_16x16x2bf16,
     mfma_f32_4x4x2bf16,

@@ -535,8 +535,7 @@ template <mfma_instr instr,
           index_t MPerXdlops_,
           index_t NPerXdlops_,
           index_t MRepeats_,
-          index_t NRepeats_,
-          class OutputVecType_>
+          index_t NRepeats_>
 struct xdlops_info
 {
     static constexpr auto mfma_type = mfma_info<instr>{};

@@ -552,8 +551,6 @@ struct xdlops_info
     {
         return (mfma_type.num_output_blks == 1) && (mfma_type.num_input_blks > 1);
     }
-
-    static constexpr auto OutputVecType = OutputVecType_{};
 };

 template <class data_type,
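Dropping OutputVecType_ means xdlops_info now carries only shape and instruction selection; the accumulator buffer is requested by element type and count instead. A compressed illustration of the slimmed-down traits (names and the per-lane-output formula here are assumptions for illustration, not taken from the commit):

// Sketch: after the change the traits expose shape only; a caller can derive
// the buffer size itself, e.g. to feed GetRegBuffer<float, N>().
// 64 below stands for the wavefront size.
template <int MPerXdlops, int NPerXdlops, int MRepeats, int NRepeats>
struct xdlops_info_sketch
{
    static constexpr int OutputsPerLane()
    {
        return MPerXdlops * NPerXdlops * MRepeats * NRepeats / 64;
    }
};

static_assert(xdlops_info_sketch<64, 64, 2, 1>::OutputsPerLane() == 128, "");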
@@ -635,27 +632,59 @@ struct XdlopsGemm_t
                         }
                     }
                 }
-            }).Else([&](auto) {
-                static_if<IsABroadcast>{}([&](auto) {
-                    for(index_t m_i = 0; m_i < MRepeats; ++m_i)
-                    {
-                        for(index_t n_i = 0; n_i < NRepeats; ++n_i)
+            })
+            .Else([&](auto) {
+                static_if<IsABroadcast>{}([&](auto) {
+                    for(index_t m_i = 0; m_i < MRepeats; ++m_i)
+                    {
+                        // ABroadcast
+                        for(index_t n_i = 0; n_i < NRepeats; ++n_i)
                         {
-                            // ABroadcast
                             for(index_t k = 0; k < K; ++k)
                             {
                                 for(index_t b = 0; b < MPerXdlops / mfma_type.m; ++b)
                                 {
                                     for(index_t n = 0; n < mfma_type.num_input_blks; ++n)
                                     {
                                         index_t a_off = k * M + b * mfma_type.m + MPerXdlops * m_i;
                                         index_t b_off =
                                             k * N + n * mfma_type.num_threads_blk + NPerXdlops * n_i;
                                         index_t c_off = n * mfma_type.num_regs_blk +
                                                         b * mfma_type.num_regs_xdlops +
                                                         (NRepeats * m_i + n_i) * GetRegSizePerXdlops();
                                         for(index_t m = 0; m < mfma_type.num_regs_blk; ++m)
                                         {
                                             index_t aindex =
                                                 m % mfma_type.group_size +
                                                 blk_id * mfma_type.group_size +
                                                 m / mfma_type.group_size *
                                                     (mfma_type.group_size * mfma_type.num_input_blks);
                                             index_t bindex = blk_td;
                                             p_c_thread.n[m + c_off] +=
                                                 inner_product_with_conversion<float>{}(
                                                     p_a_wave[aindex + a_off],
                                                     p_b_wave[bindex + b_off]);
                                         }
                                     }
                                 }
                             }
                         }
                     }
+                })
+                .Else([&](auto) {
+                    // BBroadcast
+                    for(index_t k = 0; k < K; ++k)
+                    {
-                        for(index_t b = 0; b < MPerXdlops / mfma_type.m; ++b)
+                        for(index_t b = 0; b < NPerXdlops / mfma_type.n; ++b)
                         {
                             for(index_t n = 0; n < mfma_type.num_input_blks; ++n)
                             {
-                                index_t a_off = k * M + b * mfma_type.m + MPerXdlops * m_i;
-                                index_t b_off = k * N + n * mfma_type.num_threads_blk + NPerXdlops * n_i;
-                                index_t c_off = n * mfma_type.num_regs_blk +
-                                                b * mfma_type.num_regs_xdlops +
-                                                (NRepeats * m_i + n_i) * GetRegSizePerXdlops();
+                                index_t a_off = k * M + n * mfma_type.m;
+                                index_t b_off = k * N + b * mfma_type.n;
+                                index_t c_off =
+                                    n * mfma_type.num_regs_blk + b * mfma_type.num_regs_xdlops;
                                 for(index_t m = 0; m < mfma_type.num_regs_blk; ++m)
                                 {

@@ -672,37 +701,8 @@ struct XdlopsGemm_t
                                 }
                             }
                         }
-                    }
-                }
-            }).Else([&](auto) {
-                // BBroadcast
-                for(index_t k = 0; k < K; ++k)
-                {
-                    for(index_t b = 0; b < NPerXdlops / mfma_type.n; ++b)
-                    {
-                        for(index_t n = 0; n < mfma_type.num_input_blks; ++n)
-                        {
-                            index_t a_off = k * M + n * mfma_type.m;
-                            index_t b_off = k * N + b * mfma_type.n;
-                            index_t c_off =
-                                n * mfma_type.num_regs_blk + b * mfma_type.num_regs_xdlops;
-                            for(index_t m = 0; m < mfma_type.num_regs_blk; ++m)
-                            {
-                                index_t aindex =
-                                    m % mfma_type.group_size + blk_id * mfma_type.group_size +
-                                    m / mfma_type.group_size *
-                                        (mfma_type.group_size * mfma_type.num_input_blks);
-                                index_t bindex = blk_td;
-                                p_c_thread.n[m + c_off] += inner_product_with_conversion<float>{}(
-                                    p_a_wave[aindex + a_off], p_b_wave[bindex + b_off]);
-                            }
-                        }
-                    }
-                }
-            });
+                    });
+            });
     });

     return p_c_thread;
 }
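To make the index arithmetic above concrete, here is a host-side toy that evaluates one ABroadcast iteration under assumed mfma parameters (all values are illustrative, not taken from the commit):

#include <cstdio>

int main()
{
    // Assumed toy configuration.
    const int M = 64, N = 64;
    const int mfma_m = 32, num_threads_blk = 32;
    const int num_regs_blk = 16, num_regs_xdlops = 32;
    const int MPerXdlops = 64, NPerXdlops = 64;

    // One sample point of the loop nest, with m_i = n_i = 0 so the repeat
    // terms (and the GetRegSizePerXdlops() term in c_off) drop out.
    const int k = 1, b = 1, n = 1, m_i = 0, n_i = 0;

    const int a_off = k * M + b * mfma_m + MPerXdlops * m_i;          // 96
    const int b_off = k * N + n * num_threads_blk + NPerXdlops * n_i; // 96
    const int c_off = n * num_regs_blk + b * num_regs_xdlops;         // 48

    std::printf("a_off=%d b_off=%d c_off=%d\n", a_off, b_off, c_off);
}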
@@ -745,13 +745,12 @@ struct XdlopsGemm_t
     constexpr index_t BStride = K * KRepeats;

     static_if<!IsKReduction>{}([&](auto) {
         for(index_t m_i = 0; m_i < MRepeats; ++m_i)
-            for(index_t k_i = 0; k_i < K; ++k_i)
+            for(index_t k_i = 0; k_i < K; ++k_i)
                 a[k_i + m_i * K] = p_a_wave[k_i * M + laneId + MPerXdlops * m_i];

         for(index_t n_i = 0; n_i < NRepeats; ++n_i)
-            for(index_t k_i = 0; k_i < K; ++k_i)
+            for(index_t k_i = 0; k_i < K; ++k_i)
                 b[k_i + n_i * K] = p_b_wave[k_i * N + laneId + NPerXdlops * n_i];

 #if CK_WORKAROUND_SWDEV_229564

@@ -765,32 +764,31 @@ struct XdlopsGemm_t
                                           BStride>(&pa[k_i * mfma_type.k_base],
                                                    &pb[k_i * mfma_type.k_base],
                                                    p_c_thread);
         }
-    })
-        .Else([&](auto) {
+    }).Else([&](auto) {
         const index_t blk_id = laneId / mfma_type.num_threads_blk;
         const index_t blk_td = laneId % mfma_type.num_threads_blk;

         // load into registers
         for(index_t k_i = 0; k_i < K; k_i += mfma_type.num_input_blks)
         {
             a[k_i] = p_a_wave[(k_i + blk_id) * M + blk_td];
             b[k_i] = p_b_wave[(k_i + blk_id) * N + blk_td];
         }

 #if CK_WORKAROUND_SWDEV_229564
 #pragma unroll
 #endif
         for(index_t k_i = 0; k_i < K; k_i += mfma_type.num_input_blks)
         {
             for(index_t i = 0; i < KRepeats; ++i)
                 p_c_thread = mfma_type.template run<MPerXdlops, NPerXdlops, AStride, BStride>(
                     &pa[(k_i * KRepeats + i) * mfma_type.k_base],
                     &pb[(k_i * KRepeats + i) * mfma_type.k_base],
                     p_c_thread);
         }
     });
 #endif

     return p_c_thread;
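In the K-reduction branch above, blk_id selects which k-slice a lane reads and blk_td its column, so each lane touches every num_input_blks-th row. A toy model of that access pattern (sizes assumed for illustration):

#include <cstdio>

int main()
{
    const int K = 8, M = 16;
    const int num_input_blks = 4, num_threads_blk = 16;

    const int laneId = 37;
    const int blk_id = laneId / num_threads_blk; // 2
    const int blk_td = laneId % num_threads_blk; // 5

    // Each lane reads K / num_input_blks elements of column blk_td.
    for(int k_i = 0; k_i < K; k_i += num_input_blks)
        std::printf("a[%d] = p_a_wave[%d]\n", k_i, (k_i + blk_id) * M + blk_td);
    // prints: a[0] = p_a_wave[37] and a[4] = p_a_wave[101]
}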
@@ -837,199 +835,199 @@ struct XdlopsGemm_t
 template <>
 static constexpr auto GetXdlopsInfo<float, 128, 64>()
 {
-    return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 64, 2, 1, c_vec32_4_t>{};
+    return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 64, 2, 1>{};
 }

 template <>
 static constexpr auto GetXdlopsInfo<float, 64, 128>()
 {
-    return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 64, 1, 2, c_vec32_4_t>{};
+    return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 64, 1, 2>{};
 }

 template <>
 static constexpr auto GetXdlopsInfo<float, 64, 64>()
 {
-    return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 64, 1, 1, c_vec32_2_t>{};
+    return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 64, 1, 1>{};
 }

 template <>
 static constexpr auto GetXdlopsInfo<float, 64, 32>()
 {
-    return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 32, 1, 1, c_vec32_1_t>{};
+    return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 32, 1, 1>{};
 }

 template <>
 static constexpr auto GetXdlopsInfo<float, 32, 64>()
 {
-    return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 32, 64, 1, 1, c_vec32_1_t>{};
+    return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 32, 64, 1, 1>{};
 }

 template <>
 static constexpr auto GetXdlopsInfo<float, 64, 16>()
 {
-    return xdlops_info<mfma_instr::mfma_f32_16x16x1xf32, 64, 16, 1, 1, c_vec16_1_t>{};
+    return xdlops_info<mfma_instr::mfma_f32_16x16x1xf32, 64, 16, 1, 1>{};
 }

 template <>
 static constexpr auto GetXdlopsInfo<float, 16, 64>()
 {
-    return xdlops_info<mfma_instr::mfma_f32_16x16x1xf32, 16, 64, 1, 1, c_vec16_1_t>{};
+    return xdlops_info<mfma_instr::mfma_f32_16x16x1xf32, 16, 64, 1, 1>{};
 }

 template <>
 static constexpr auto GetXdlopsInfo<float, 8, 64>()
 {
-    return xdlops_info<mfma_instr::mfma_f32_4x4x1xf32, 8, 64, 1, 1, c_vec4_2_t>{};
+    return xdlops_info<mfma_instr::mfma_f32_4x4x1xf32, 8, 64, 1, 1>{};
 }

 template <>
 static constexpr auto GetXdlopsInfo<float, 4, 64>()
 {
-    return xdlops_info<mfma_instr::mfma_f32_4x4x1xf32, 4, 64, 1, 1, c_vec4_1_t>{};
+    return xdlops_info<mfma_instr::mfma_f32_4x4x1xf32, 4, 64, 1, 1>{};
 }

 template <>
 static constexpr auto GetXdlopsInfo<float, 32, 32>()
 {
-    return xdlops_info<mfma_instr::mfma_f32_32x32x2xf32, 32, 32, 1, 1, c_vec16_1_t>{};
+    return xdlops_info<mfma_instr::mfma_f32_32x32x2xf32, 32, 32, 1, 1>{};
 }

 template <>
 static constexpr auto GetXdlopsInfo<float, 16, 16>()
 {
-    return xdlops_info<mfma_instr::mfma_f32_16x16x4xf32, 16, 16, 1, 1, c_vec4_1_t>{};
+    return xdlops_info<mfma_instr::mfma_f32_16x16x4xf32, 16, 16, 1, 1>{};
 }

 template <>
 static constexpr auto GetXdlopsInfo<half_t, 128, 64>()
 {
-    return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64, 2, 1, c_vec32_4_t>{};
+    return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64, 2, 1>{};
 }

 template <>
 static constexpr auto GetXdlopsInfo<half_t, 64, 128>()
 {
-    return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64, 1, 2, c_vec32_4_t>{};
+    return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64, 1, 2>{};
 }

 template <>
 static constexpr auto GetXdlopsInfo<half_t, 64, 64>()
 {
-    return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64, 1, 1, c_vec32_2_t>{};
+    return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64, 1, 1>{};
 }

 template <>
 static constexpr auto GetXdlopsInfo<half_t, 64, 32>()
 {
-    return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 32, 1, 1, c_vec32_1_t>{};
+    return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 32, 1, 1>{};
 }

 template <>
 static constexpr auto GetXdlopsInfo<half_t, 32, 64>()
 {
-    return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 32, 64, 1, 1, c_vec32_1_t>{};
+    return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 32, 64, 1, 1>{};
 }

 template <>
 static constexpr auto GetXdlopsInfo<half_t, 64, 16>()
 {
-    return xdlops_info<mfma_instr::mfma_f32_16x16x4f16, 64, 16, 1, 1, c_vec16_1_t>{};
+    return xdlops_info<mfma_instr::mfma_f32_16x16x4f16, 64, 16, 1, 1>{};
 }

 template <>
 static constexpr auto GetXdlopsInfo<half_t, 16, 64>()
 {
-    return xdlops_info<mfma_instr::mfma_f32_16x16x4f16, 16, 64, 1, 1, c_vec16_1_t>{};
+    return xdlops_info<mfma_instr::mfma_f32_16x16x4f16, 16, 64, 1, 1>{};
 }

 template <>
 static constexpr auto GetXdlopsInfo<half_t, 8, 64>()
 {
-    return xdlops_info<mfma_instr::mfma_f32_4x4x4f16, 8, 64, 1, 1, c_vec4_2_t>{};
+    return xdlops_info<mfma_instr::mfma_f32_4x4x4f16, 8, 64, 1, 1>{};
 }

 template <>
 static constexpr auto GetXdlopsInfo<half_t, 4, 64>()
 {
-    return xdlops_info<mfma_instr::mfma_f32_4x4x4f16, 4, 64, 1, 1, c_vec4_1_t>{};
+    return xdlops_info<mfma_instr::mfma_f32_4x4x4f16, 4, 64, 1, 1>{};
 }

 template <>
 static constexpr auto GetXdlopsInfo<half_t, 32, 32>()
 {
-    return xdlops_info<mfma_instr::mfma_f32_32x32x8f16, 32, 32, 1, 1, c_vec16_1_t>{};
+    return xdlops_info<mfma_instr::mfma_f32_32x32x8f16, 32, 32, 1, 1>{};
 }

 template <>
 static constexpr auto GetXdlopsInfo<half_t, 16, 16>()
 {
-    return xdlops_info<mfma_instr::mfma_f32_16x16x16f16, 16, 16, 1, 1, c_vec4_1_t>{};
+    return xdlops_info<mfma_instr::mfma_f32_16x16x16f16, 16, 16, 1, 1>{};
 }

 template <>
 static constexpr auto GetXdlopsInfo<ushort, 128, 64>()
 {
-    return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 64, 2, 1, c_vec32_4_t>{};
+    return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 64, 2, 1>{};
 }

 template <>
 static constexpr auto GetXdlopsInfo<ushort, 64, 128>()
 {
-    return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 64, 1, 2, c_vec32_4_t>{};
+    return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 64, 1, 2>{};
 }

 template <>
 static constexpr auto GetXdlopsInfo<ushort, 64, 64>()
 {
-    return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 64, 1, 1, c_vec32_2_t>{};
+    return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 64, 1, 1>{};
 }

 template <>
 static constexpr auto GetXdlopsInfo<ushort, 64, 32>()
 {
-    return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 32, 1, 1, c_vec32_1_t>{};
+    return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 32, 1, 1>{};
 }

 template <>
 static constexpr auto GetXdlopsInfo<ushort, 32, 64>()
 {
-    return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 32, 64, 1, 1, c_vec32_1_t>{};
+    return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 32, 64, 1, 1>{};
 }

 template <>
 static constexpr auto GetXdlopsInfo<ushort, 64, 16>()
 {
-    return xdlops_info<mfma_instr::mfma_f32_16x16x2bf16, 64, 16, 1, 1, c_vec16_1_t>{};
+    return xdlops_info<mfma_instr::mfma_f32_16x16x2bf16, 64, 16, 1, 1>{};
 }

 template <>
 static constexpr auto GetXdlopsInfo<ushort, 16, 64>()
 {
-    return xdlops_info<mfma_instr::mfma_f32_16x16x2bf16, 16, 64, 1, 1, c_vec16_1_t>{};
+    return xdlops_info<mfma_instr::mfma_f32_16x16x2bf16, 16, 64, 1, 1>{};
 }

 template <>
 static constexpr auto GetXdlopsInfo<ushort, 8, 64>()
 {
-    return xdlops_info<mfma_instr::mfma_f32_4x4x2bf16, 8, 64, 1, 1, c_vec4_2_t>{};
+    return xdlops_info<mfma_instr::mfma_f32_4x4x2bf16, 8, 64, 1, 1>{};
 }

 template <>
 static constexpr auto GetXdlopsInfo<ushort, 4, 64>()
 {
-    return xdlops_info<mfma_instr::mfma_f32_4x4x2bf16, 4, 64, 1, 1, c_vec4_1_t>{};
+    return xdlops_info<mfma_instr::mfma_f32_4x4x2bf16, 4, 64, 1, 1>{};
 }

 template <>
 static constexpr auto GetXdlopsInfo<ushort, 32, 32>()
 {
-    return xdlops_info<mfma_instr::mfma_f32_32x32x4bf16, 32, 32, 1, 1, c_vec16_1_t>{};
+    return xdlops_info<mfma_instr::mfma_f32_32x32x4bf16, 32, 32, 1, 1>{};
 }

 template <>
 static constexpr auto GetXdlopsInfo<ushort, 16, 16>()
 {
-    return xdlops_info<mfma_instr::mfma_f32_16x16x8bf16, 16, 16, 1, 1, c_vec4_1_t>{};
+    return xdlops_info<mfma_instr::mfma_f32_16x16x8bf16, 16, 16, 1, 1>{};
 }

 static constexpr index_t MRepeats = GetXdlopsInfo().MRepeats;

@@ -1055,11 +1053,6 @@ struct XdlopsGemm_t
     {
         return GetNumBlksPerXdlops() * MRepeats * NRepeats;
     }

-    __device__ static constexpr auto CreateOutputVecZero()
-    {
-        return GetXdlopsInfo().OutputVecType.CreateVecZero();
-    }
-
     __device__ static constexpr auto GetOutputLayout() { return OutputLayout{}; }
 };
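The specialization table above is a compile-time lookup from (data type, MPerWave, NPerWave) to an mfma instruction plus repeat counts. A condensed sketch of the same dispatch idiom with a single entry (enum values and member names simplified, not the library's real identifiers):

#include <cstdio>

enum class mfma_instr_sketch { f32_32x32x1xf32 };

template <mfma_instr_sketch instr, int MPerXdlops, int NPerXdlops, int MRepeats, int NRepeats>
struct xdlops_entry
{
    static constexpr int m_repeats = MRepeats;
    static constexpr int n_repeats = NRepeats;
};

template <typename data_type, int MPerWave, int NPerWave>
constexpr auto GetXdlopsInfoSketch();

template <>
constexpr auto GetXdlopsInfoSketch<float, 128, 64>()
{
    // 128 rows per wave = one 64-wide xdlops pass repeated twice in M.
    return xdlops_entry<mfma_instr_sketch::f32_32x32x1xf32, 64, 64, 2, 1>{};
}

int main()
{
    constexpr auto info = GetXdlopsInfoSketch<float, 128, 64>();
    std::printf("%d %d\n", info.m_repeats, info.n_repeats); // 2 1
}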
composable_kernel/include/utility/amd_xdlops.hpp

@@ -95,10 +95,10 @@ struct intrin_mfma_f32_32x32x1f32<64, 64, AStride, BStride>
 {
     __device__ static float_vec64_t
     run(const float* reg_a, const float* reg_b, float_vec64_t reg_c)
     {
-        reg_c.At(Number<32>{})(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
-            reg_a[0], reg_b[0], reg_c.At(Number<32>{})[Number<0>{}], 1, 0, 0);
-        reg_c.At(Number<32>{})(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
-            reg_a[0], reg_b[0], reg_c.At(Number<32>{})[Number<1>{}], 1, 1, 0);
+        reg_c.v32(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
+            reg_a[0], reg_b[0], reg_c.v32[Number<0>{}], 1, 0, 0);
+        reg_c.v32(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
+            reg_a[0], reg_b[0], reg_c.v32[Number<1>{}], 1, 1, 0);
         return reg_c;
     }
 };
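The accessor change above leans on the new named v32 view of the accumulator union, added in the next file. A plain-C++ sketch of the union-of-views idea; like the real code, it relies on type punning through the union (types simplified to plain arrays):

#include <cstdio>

struct float32_stub { float v[32]; }; // stand-in for float32_t

union float_vec64_sketch
{
    float        s1[64]; // scalar view
    float32_stub v32[2]; // two 32-wide accumulator views
};

int main()
{
    float_vec64_sketch reg_c = {};     // zero-initialize through the first view
    reg_c.v32[1].v[0] = 3.0f;          // write through the 32-wide view...
    std::printf("%f\n", reg_c.s1[32]); // ...read as scalar: 3.000000
}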
composable_kernel/include/utility/float_type.amd.hpp.in

@@ -186,7 +186,8 @@ union float_vec32_t
 union float_vec64_t
 {
     StaticallyIndexedArray<float, 64> s1;
-    StaticallyIndexedArray<float32_t, 2> s32;
+    StaticallyIndexedArray<float_vec32_t, 2> s32;
+    StaticallyIndexedArray<float32_t, 2> v32;
     StaticallyIndexedArray<float64_t, 1> s64;

     __host__ __device__ constexpr float_vec64_t() {}

@@ -210,7 +211,7 @@ union float_vec128_t
 {
     StaticallyIndexedArray<float, 64> s1;
     StaticallyIndexedArray<float_vec16_t, 8> s16;
-    StaticallyIndexedArray<float32_t, 4> s32;
+    StaticallyIndexedArray<float_vec32_t, 4> s32;
     StaticallyIndexedArray<float_vec64_t, 2> s64;
     StaticallyIndexedArray<float128_t, 1> s128;

     __host__ __device__ constexpr float_vec128_t() {}

@@ -264,6 +265,18 @@ constexpr auto GetRegBuffer<float, 16>()
     return float_vec16_t{};
 }

+template <>
+constexpr auto GetRegBuffer<float, 64>()
+{
+    return float_vec64_t{};
+}
+
+template <>
+constexpr auto GetRegBuffer<float, 128>()
+{
+    return float_vec128_t{};
+}
+
 struct c_vec32_4_t
 {
     union VecType
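The two new GetRegBuffer specializations close the loop with the gridwise change earlier in this commit: a c_thread_size of 64 or 128 now resolves to float_vec64_t or float_vec128_t. A trimmed compile-time sketch of that explicit enumeration (union members cut down to one; names suffixed to mark them as stand-ins):

union float_vec64_sk  { float s1[64];  };
union float_vec128_sk { float s1[128]; };

template <typename T, int N>
constexpr auto GetRegBufferSk(); // unsupported sizes simply have no definition

template <>
constexpr auto GetRegBufferSk<float, 64>() { return float_vec64_sk{}; }

template <>
constexpr auto GetRegBufferSk<float, 128>() { return float_vec128_sk{}; }

// The supported buffer sizes are enumerated explicitly; anything else
// fails at build time, which is the point of the dispatch.
static_assert(sizeof(GetRegBufferSk<float, 128>()) == 128 * sizeof(float), "");

int main() { return 0; }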