Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
5c27dcd5
Commit
5c27dcd5
authored
May 26, 2021
by
Jing Zhang
Browse files
add fp32 mfma instructions
parent
21755b5d
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
132 additions
and
173 deletions
+132
-173
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
...tion_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
+3
-2
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
...kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
+1
-1
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp
...include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp
+11
-11
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
+22
-55
composable_kernel/include/utility/amd_xdlops.hpp
composable_kernel/include/utility/amd_xdlops.hpp
+84
-94
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
...tion_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
+11
-10
No files found.
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
View file @
5c27dcd5
...
@@ -11,7 +11,8 @@ namespace ck {
...
@@ -11,7 +11,8 @@ namespace ck {
// GemmM = K
// GemmM = K
// GemmN = N * Ho * Wo
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
// GemmK = C * Y * X
template
<
index_t
GemmMPerBlock
,
template
<
typename
FloatAB
,
index_t
GemmMPerBlock
,
index_t
GemmNPerBlock
,
index_t
GemmNPerBlock
,
index_t
GemmMPerWave
,
index_t
GemmMPerWave
,
index_t
GemmNPerWave
,
index_t
GemmNPerWave
,
...
@@ -109,7 +110,7 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
...
@@ -109,7 +110,7 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
assert
(
GemmM
%
GemmMPerBlock
==
0
&&
GemmN
%
GemmNPerBlock
==
0
&&
GemmK
%
GemmKPerBlock
==
0
);
assert
(
GemmM
%
GemmMPerBlock
==
0
&&
GemmN
%
GemmNPerBlock
==
0
&&
GemmK
%
GemmKPerBlock
==
0
);
constexpr
auto
xdlops_gemm
=
XdlopsGemm
<
f
loat
,
GemmMPerWave
,
GemmNPerWave
,
GemmKPerWave
>
{};
constexpr
auto
xdlops_gemm
=
XdlopsGemm
<
F
loat
AB
,
GemmMPerWave
,
GemmNPerWave
,
GemmKPerWave
>
{};
constexpr
auto
CLayout
=
xdlops_gemm
.
GetCLayout
();
constexpr
auto
CLayout
=
xdlops_gemm
.
GetCLayout
();
...
...
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
View file @
5c27dcd5
...
@@ -34,7 +34,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
...
@@ -34,7 +34,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
static
constexpr
index_t
N0
=
BBlockDesc
{}.
GetLength
(
I1
);
static
constexpr
index_t
N0
=
BBlockDesc
{}.
GetLength
(
I1
);
static
constexpr
index_t
N1
=
BBlockDesc
{}.
GetLength
(
I2
);
static
constexpr
index_t
N1
=
BBlockDesc
{}.
GetLength
(
I2
);
static
constexpr
auto
xdlops_gemm
=
XdlopsGemm
<
f
loat
,
MPerWave
,
NPerWave
,
KPack
>
{};
static
constexpr
auto
xdlops_gemm
=
XdlopsGemm
<
F
loat
A
,
MPerWave
,
NPerWave
,
KPack
>
{};
static
constexpr
index_t
MWaves
=
M1
/
MPerWave
;
static
constexpr
index_t
MWaves
=
M1
/
MPerWave
;
static
constexpr
index_t
NWaves
=
N1
/
NPerWave
;
static
constexpr
index_t
NWaves
=
N1
/
NPerWave
;
...
...
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp
View file @
5c27dcd5
...
@@ -306,14 +306,14 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
...
@@ -306,14 +306,14 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
make_tuple
(
Sequence
<
0
,
3
>
{},
Sequence
<
1
,
2
>
{}));
make_tuple
(
Sequence
<
0
,
3
>
{},
Sequence
<
1
,
2
>
{}));
const
auto
blockwise_gemm
=
const
auto
blockwise_gemm
=
BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
_2x2pipeline
<
BlockSize
,
BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
<
BlockSize
,
FloatAB
,
FloatAB
,
FloatAB
,
FloatAB
,
decltype
(
a_k0_m0_m1_k1_block_desc
),
decltype
(
a_k0_m0_m1_k1_block_desc
),
decltype
(
b_k0_n0_n1_k1_block_desc
),
decltype
(
b_k0_n0_n1_k1_block_desc
),
MPerWave
,
MPerWave
,
NPerWave
,
NPerWave
,
KPack
>
{};
KPack
>
{};
constexpr
auto
CLayout
=
blockwise_gemm
.
GetCLayout
();
constexpr
auto
CLayout
=
blockwise_gemm
.
GetCLayout
();
constexpr
index_t
BlkSize
=
CLayout
.
GetBlkSize
();
constexpr
index_t
BlkSize
=
CLayout
.
GetBlkSize
();
...
@@ -327,7 +327,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
...
@@ -327,7 +327,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
make_tuple
(
Number
<
NumBlks
>
{},
Number
<
BlkSize
>
{}));
make_tuple
(
Number
<
NumBlks
>
{},
Number
<
BlkSize
>
{}));
StaticBuffer
<
AddressSpace
::
Vgpr
,
StaticBuffer
<
AddressSpace
::
Vgpr
,
vector_type
<
f
loat
,
c_blk_nb_bs_desc
.
GetElementSpaceSize
()
>
,
vector_type
<
F
loat
AB
,
c_blk_nb_bs_desc
.
GetElementSpaceSize
()
>
,
c_mr_nr_nx_desc
.
GetElementSpaceSize
()
>
c_mr_nr_nx_desc
.
GetElementSpaceSize
()
>
c_thread_buf
;
c_thread_buf
;
...
@@ -488,7 +488,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
...
@@ -488,7 +488,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
Number
<
M0
>
{},
Number
<
1
>
{},
Number
<
M2
>
{},
Number
<
1
>
{}));
make_tuple
(
Number
<
M0
>
{},
Number
<
1
>
{},
Number
<
M2
>
{},
Number
<
1
>
{}));
StaticBuffer
<
AddressSpace
::
Vgpr
,
f
loat
,
BlkSize
>
c_blk_buf_
;
StaticBuffer
<
AddressSpace
::
Vgpr
,
F
loat
AB
,
BlkSize
>
c_blk_buf_
;
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
mr_i
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
mr_i
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
nr_i
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
nr_i
)
{
...
@@ -498,7 +498,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
...
@@ -498,7 +498,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
make_tuple
(
mr_i
,
nr_i
,
xdlops_i
))
>
{}];
make_tuple
(
mr_i
,
nr_i
,
xdlops_i
))
>
{}];
static_for
<
0
,
BlkSize
,
1
>
{}([
&
](
auto
j
)
{
static_for
<
0
,
BlkSize
,
1
>
{}([
&
](
auto
j
)
{
c_blk_buf_
(
j
)
=
c_blk
.
template
AsType
<
f
loat
>()[
Number
<
c_blk_buf_
(
j
)
=
c_blk
.
template
AsType
<
F
loat
AB
>()[
Number
<
c_blk_nb_bs_desc
.
CalculateOffset
(
make_tuple
(
blk_i
,
j
))
>
{}];
c_blk_nb_bs_desc
.
CalculateOffset
(
make_tuple
(
blk_i
,
j
))
>
{}];
});
});
...
...
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
View file @
5c27dcd5
...
@@ -10,19 +10,19 @@ namespace ck {
...
@@ -10,19 +10,19 @@ namespace ck {
enum
struct
mfma_instr
enum
struct
mfma_instr
{
{
// fp32
//
/
fp32
mfma_f32_32x32x1xf32
=
0
,
mfma_f32_32x32x1xf32
=
0
,
mfma_f32_16x16x1xf32
,
mfma_f32_16x16x1xf32
,
mfma_f32_4x4x1xf32
,
mfma_f32_4x4x1xf32
,
mfma_f32_32x32x2xf32
,
// k reduction
mfma_f32_32x32x2xf32
,
// k reduction
mfma_f32_16x16x4xf32
,
// k reduction
mfma_f32_16x16x4xf32
,
// k reduction
// fp16
//
/
fp16
mfma_f32_32x32x4f16
,
mfma_f32_32x32x4f16
,
mfma_f32_16x16x4f16
,
mfma_f32_16x16x4f16
,
mfma_f32_4x4x4f16
,
mfma_f32_4x4x4f16
,
mfma_f32_32x32x8f16
,
// k reduction
mfma_f32_32x32x8f16
,
// k reduction
mfma_f32_16x16x16f16
,
// k reduction
mfma_f32_16x16x16f16
,
// k reduction
// bfp16
//
/
bfp16
mfma_f32_32x32x2bf16
,
mfma_f32_32x32x2bf16
,
mfma_f32_16x16x2bf16
,
mfma_f32_16x16x2bf16
,
mfma_f32_4x4x2bf16
,
mfma_f32_4x4x2bf16
,
...
@@ -58,7 +58,7 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x1xf32>
...
@@ -58,7 +58,7 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x1xf32>
class
FloatC
>
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
{
{
return
intrin_mfma_f32_32x32x1f32
<
MPerXdlops
,
NPerXdlops
,
COffset
>::
Run
(
a
,
b
,
reg_c
);
intrin_mfma_f32_32x32x1f32
<
MPerXdlops
,
NPerXdlops
,
COffset
>::
Run
(
a
,
b
,
reg_c
);
}
}
};
};
...
@@ -87,7 +87,7 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2xf32>
...
@@ -87,7 +87,7 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2xf32>
class
FloatC
>
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
{
{
return
intrin_mfma_f32_32x32x2f32
<
MPerXdlops
,
NPerXdlops
,
COffset
>::
Run
(
a
,
b
,
reg_c
);
intrin_mfma_f32_32x32x2f32
<
MPerXdlops
,
NPerXdlops
,
COffset
>::
Run
(
a
,
b
,
reg_c
);
}
}
};
};
...
@@ -110,17 +110,13 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x4xf32>
...
@@ -110,17 +110,13 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x4xf32>
template
<
index_t
MPerXdlops
,
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
NPerXdlops
,
index_t
AStride
,
index_t
COffset
,
index_t
BStride
,
class
FloatA
,
class
FloatA
,
class
FloatB
,
class
FloatB
,
class
FloatC
>
class
FloatC
>
__device__
FloatC
run
(
const
FloatA
*
a
,
const
FloatB
*
b
,
FloatC
reg_c
)
const
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
{
{
const
auto
p_a
=
reinterpret_cast
<
const
float
*>
(
a
);
intrin_mfma_f32_16x16x4f32
<
MPerXdlops
,
NPerXdlops
,
COffset
>::
Run
(
a
,
b
,
reg_c
);
const
auto
p_b
=
reinterpret_cast
<
const
float
*>
(
b
);
return
intrin_mfma_f32_16x16x4f32
(
p_a
,
p_b
,
reg_c
);
}
}
};
};
...
@@ -143,17 +139,13 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x1xf32>
...
@@ -143,17 +139,13 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x1xf32>
template
<
index_t
MPerXdlops
,
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
NPerXdlops
,
index_t
AStride
,
index_t
COffset
,
index_t
BStride
,
class
FloatA
,
class
FloatA
,
class
FloatB
,
class
FloatB
,
class
FloatC
>
class
FloatC
>
__device__
FloatC
run
(
const
FloatA
*
a
,
const
FloatB
*
b
,
FloatC
reg_c
)
const
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
{
{
const
auto
p_a
=
reinterpret_cast
<
const
float
*>
(
a
);
intrin_mfma_f32_16x16x1f32
<
MPerXdlops
,
NPerXdlops
,
COffset
>::
Run
(
a
,
b
,
reg_c
);
const
auto
p_b
=
reinterpret_cast
<
const
float
*>
(
b
);
return
intrin_mfma_f32_16x16x1f32
<
MPerXdlops
,
NPerXdlops
>
(
p_a
,
p_b
,
reg_c
);
}
}
};
};
...
@@ -177,17 +169,13 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x1xf32>
...
@@ -177,17 +169,13 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x1xf32>
template
<
index_t
MPerXdlops
,
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
NPerXdlops
,
index_t
AStride
,
index_t
COffset
,
index_t
BStride
,
class
FloatA
,
class
FloatA
,
class
FloatB
,
class
FloatB
,
class
FloatC
>
class
FloatC
>
__device__
FloatC
run
(
const
FloatA
*
a
,
const
FloatB
*
b
,
FloatC
reg_c
)
const
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
{
{
const
auto
p_a
=
reinterpret_cast
<
const
float
*>
(
a
);
intrin_mfma_f32_4x4x1f32
<
MPerXdlops
,
NPerXdlops
,
COffset
>::
Run
(
a
,
b
,
reg_c
);
const
auto
p_b
=
reinterpret_cast
<
const
float
*>
(
b
);
return
intrin_mfma_f32_4x4x1f32
<
MPerXdlops
,
NPerXdlops
>::
run
(
p_a
,
p_b
,
reg_c
);
}
}
};
};
...
@@ -523,20 +511,13 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16>
...
@@ -523,20 +511,13 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16>
}
}
};
};
template
<
mfma_instr
instr
,
template
<
mfma_instr
instr
,
index_t
MPerXdlops_
,
index_t
NPerXdlops_
>
index_t
MPerXdlops_
,
index_t
NPerXdlops_
,
index_t
MRepeats_
,
index_t
NRepeats_
,
class
CType_
>
struct
xdlops_info
struct
xdlops_info
{
{
static
constexpr
auto
mfma_type
=
mfma_info
<
instr
>
{};
static
constexpr
auto
mfma_type
=
mfma_info
<
instr
>
{};
static
constexpr
index_t
MPerXdlops
=
MPerXdlops_
;
static
constexpr
index_t
MPerXdlops
=
MPerXdlops_
;
static
constexpr
index_t
NPerXdlops
=
NPerXdlops_
;
static
constexpr
index_t
NPerXdlops
=
NPerXdlops_
;
static
constexpr
index_t
MRepeats
=
MRepeats_
;
static
constexpr
index_t
NRepeats
=
NRepeats_
;
static
constexpr
bool
IsABroadcast
()
static
constexpr
bool
IsABroadcast
()
{
{
...
@@ -555,8 +536,6 @@ struct xdlops_info
...
@@ -555,8 +536,6 @@ struct xdlops_info
}
}
static
constexpr
index_t
GetNumCRegs
()
{
return
MPerXdlops
*
NPerXdlops
/
mfma_type
.
wave_size
;
}
static
constexpr
index_t
GetNumCRegs
()
{
return
MPerXdlops
*
NPerXdlops
/
mfma_type
.
wave_size
;
}
static
constexpr
auto
GetCType
()
{
return
CType_
{};
}
};
};
template
<
class
base_type
,
index_t
MPerWave
,
index_t
NPerWave
,
index_t
KPack
>
template
<
class
base_type
,
index_t
MPerWave
,
index_t
NPerWave
,
index_t
KPack
>
...
@@ -570,55 +549,43 @@ struct XdlopsGemm
...
@@ -570,55 +549,43 @@ struct XdlopsGemm
template
<
>
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
float
,
64
,
64
>
()
static
constexpr
auto
GetXdlopsInfo
<
float
,
64
,
64
>
()
{
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_32x32x1xf32
,
64
,
64
,
1
,
1
,
float64_t
>
{};
return
xdlops_info
<
mfma_instr
::
mfma_f32_32x32x1xf32
,
64
,
64
>
{};
}
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
float
,
64
,
32
>
()
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_32x32x1xf32
,
64
,
32
,
1
,
1
,
float32_t
>
{};
}
}
template
<
>
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
float
,
32
,
64
>
()
static
constexpr
auto
GetXdlopsInfo
<
float
,
32
,
64
>
()
{
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_32x32x1xf32
,
32
,
64
,
1
,
1
,
float32_t
>
{};
return
xdlops_info
<
mfma_instr
::
mfma_f32_32x32x1xf32
,
32
,
64
>
{};
}
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
float
,
64
,
16
>
()
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_16x16x1xf32
,
64
,
16
,
1
,
1
,
float16_t
>
{};
}
}
template
<
>
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
float
,
16
,
64
>
()
static
constexpr
auto
GetXdlopsInfo
<
float
,
16
,
64
>
()
{
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_16x16x1xf32
,
16
,
64
,
1
,
1
,
float16_t
>
{};
return
xdlops_info
<
mfma_instr
::
mfma_f32_16x16x1xf32
,
16
,
64
>
{};
}
}
template
<
>
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
float
,
8
,
64
>
()
static
constexpr
auto
GetXdlopsInfo
<
float
,
8
,
64
>
()
{
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_4x4x1xf32
,
8
,
64
,
1
,
1
,
float8_t
>
{};
return
xdlops_info
<
mfma_instr
::
mfma_f32_4x4x1xf32
,
8
,
64
>
{};
}
}
template
<
>
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
float
,
4
,
64
>
()
static
constexpr
auto
GetXdlopsInfo
<
float
,
4
,
64
>
()
{
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_4x4x1xf32
,
4
,
64
,
1
,
1
,
float4_t
>
{};
return
xdlops_info
<
mfma_instr
::
mfma_f32_4x4x1xf32
,
4
,
64
>
{};
}
}
template
<
>
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
float
,
32
,
32
>
()
static
constexpr
auto
GetXdlopsInfo
<
float
,
32
,
32
>
()
{
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_32x32x2xf32
,
32
,
32
,
1
,
1
,
float16_t
>
{};
return
xdlops_info
<
mfma_instr
::
mfma_f32_32x32x2xf32
,
32
,
32
>
{};
}
}
template
<
>
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
float
,
16
,
16
>
()
static
constexpr
auto
GetXdlopsInfo
<
float
,
16
,
16
>
()
{
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_16x16x4xf32
,
16
,
16
,
1
,
1
,
float4_t
>
{};
return
xdlops_info
<
mfma_instr
::
mfma_f32_16x16x4xf32
,
16
,
16
>
{};
}
}
#if 0
#if 0
...
...
composable_kernel/include/utility/amd_xdlops.hpp
View file @
5c27dcd5
...
@@ -201,42 +201,6 @@ extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(
...
@@ -201,42 +201,6 @@ extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(
template
<
index_t
MPerWave
,
index_t
NPerWave
,
index_t
COffset
>
template
<
index_t
MPerWave
,
index_t
NPerWave
,
index_t
COffset
>
struct
intrin_mfma_f32_32x32x1f32
;
struct
intrin_mfma_f32_32x32x1f32
;
// template <index_t AStride, index_t BStride>
// struct intrin_mfma_f32_32x32x1f32<128, 64, AStride, BStride>
//{
//__device__ static c_vec32_4_t::VecType
// run(const float* reg_a, const float* reg_b, c_vec32_4_t::VecType reg_c)
//{
// reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
// reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
// reg_c.s.z =
// llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[AStride], reg_b[0], reg_c.s.z, 1, 0, 0);
// reg_c.s.w =
// llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[AStride], reg_b[0], reg_c.s.w, 1, 1, 0);
// return reg_c;
//}
//};
// template <index_t AStride, index_t BStride>
// struct intrin_mfma_f32_32x32x1f32<64, 128, AStride, BStride>
//{
//__device__ static c_vec32_4_t::VecType
// run(const float* reg_a, const float* reg_b, c_vec32_4_t::VecType reg_c)
//{
// reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
// reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
// reg_c.s.z =
// llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[BStride], reg_c.s.z, 1, 0, 0);
// reg_c.s.w =
// llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[BStride], reg_c.s.w, 1, 1, 0);
// return reg_c;
//}
//};
template
<
index_t
COffset
>
template
<
index_t
COffset
>
struct
intrin_mfma_f32_32x32x1f32
<
64
,
64
,
COffset
>
struct
intrin_mfma_f32_32x32x1f32
<
64
,
64
,
COffset
>
{
{
...
@@ -262,27 +226,22 @@ struct intrin_mfma_f32_32x32x1f32<64, 64, COffset>
...
@@ -262,27 +226,22 @@ struct intrin_mfma_f32_32x32x1f32<64, 64, COffset>
}
}
};
};
// template <index_t AStride, index_t BStride>
template
<
index_t
COffset
>
// struct intrin_mfma_f32_32x32x1f32<64, 32, AStride, BStride>
struct
intrin_mfma_f32_32x32x1f32
<
32
,
64
,
COffset
>
//{
{
//__device__ static c_vec32_1_t::VecType
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
>
// run(const float* reg_a, const float* reg_b, c_vec32_1_t::VecType reg_c)
__device__
static
void
Run
(
const
FloatA
&
reg_a
,
const
FloatB
&
reg_b
,
FloatC
&
reg_c
)
//{
{
// reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 1);
reg_c
(
Number
<
COffset
>
{}).
template
AsType
<
float32_t
>()(
Number
<
0
>
{})
=
// return reg_c;
llvm_intrin_amdgcn_mfma_f32_32x32x1f32
(
//}
reg_a
,
//};
reg_b
,
reg_c
[
Number
<
COffset
>
{}].
template
AsType
<
float32_t
>()[
Number
<
0
>
{}],
// template <index_t AStride, index_t BStride>
1
,
// struct intrin_mfma_f32_32x32x1f32<32, 64, AStride, BStride>
0
,
//{
0
);
//__device__ static c_vec32_1_t::VecType
}
// run(const float* reg_a, const float* reg_b, c_vec32_1_t::VecType reg_c)
};
//{
// reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
// return reg_c;
//}
//};
template
<
index_t
MPerWave
,
index_t
NPerWave
,
index_t
COffset
>
template
<
index_t
MPerWave
,
index_t
NPerWave
,
index_t
COffset
>
struct
intrin_mfma_f32_32x32x2f32
;
struct
intrin_mfma_f32_32x32x2f32
;
...
@@ -304,58 +263,89 @@ struct intrin_mfma_f32_32x32x2f32<32, 32, COffset>
...
@@ -304,58 +263,89 @@ struct intrin_mfma_f32_32x32x2f32<32, 32, COffset>
}
}
};
};
__device__
c_vec4_1_t
::
VecType
template
<
index_t
MPerWave
,
index_t
NPerWave
,
index_t
COffset
>
intrin_mfma_f32_16x16x4f32
(
const
float
*
reg_a
,
const
float
*
reg_b
,
c_vec4_1_t
::
VecType
reg_c
)
struct
intrin_mfma_f32_16x16x4f32
;
template
<
index_t
COffset
>
struct
intrin_mfma_f32_16x16x4f32
<
16
,
16
,
COffset
>
{
{
reg_c
.
s
.
x
=
llvm_intrin_amdgcn_mfma_f32_16x16x4f32
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
x
,
0
,
0
,
0
);
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
>
return
reg_c
;
__device__
static
void
Run
(
const
FloatA
&
reg_a
,
const
FloatB
&
reg_b
,
FloatC
&
reg_c
)
}
{
reg_c
(
Number
<
COffset
>
{}).
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
llvm_intrin_amdgcn_mfma_f32_16x16x4f32
(
reg_a
,
reg_b
,
reg_c
[
Number
<
COffset
>
{}].
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
}
};
template
<
index_t
MPerWave
,
index_t
NPerWave
>
template
<
index_t
MPerWave
,
index_t
NPerWave
,
index_t
COffset
>
__device__
c_vec16_1_t
::
VecType
struct
intrin_mfma_f32_16x16x1f32
;
intrin_mfma_f32_16x16x1f32
(
const
float
*
reg_a
,
const
float
*
reg_b
,
c_vec16_1_t
::
VecType
reg_c
);
template
<
>
template
<
index_t
COffset
>
__device__
c_vec16_1_t
::
VecType
intrin_mfma_f32_16x16x1f32
<
16
,
64
>
(
const
float
*
reg_a
,
struct
intrin_mfma_f32_16x16x1f32
<
16
,
64
,
COffset
>
const
float
*
reg_b
,
c_vec16_1_t
::
VecType
reg_c
)
{
{
reg_c
.
s
.
x
=
llvm_intrin_amdgcn_mfma_f32_16x16x1f32
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
x
,
2
,
0
,
0
);
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
>
return
reg_c
;
__device__
static
void
Run
(
const
FloatA
&
reg_a
,
const
FloatB
&
reg_b
,
FloatC
&
reg_c
)
}
{
template
<
>
reg_c
(
Number
<
COffset
>
{}).
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
__device__
c_vec16_1_t
::
VecType
intrin_mfma_f32_16x16x1f32
<
64
,
16
>
(
const
float
*
reg_a
,
llvm_intrin_amdgcn_mfma_f32_16x16x1f32
(
const
float
*
reg_b
,
reg_a
,
c_vec16_1_t
::
VecType
reg_c
)
reg_b
,
{
reg_c
[
Number
<
COffset
>
{}].
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
reg_c
.
s
.
x
=
llvm_intrin_amdgcn_mfma_f32_16x16x1f32
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
x
,
0
,
0
,
4
);
2
,
return
reg_c
;
0
,
}
0
);
}
};
template
<
index_t
MPerWave
,
index_t
NPerWave
>
template
<
index_t
MPerWave
,
index_t
NPerWave
,
index_t
COffset
>
struct
intrin_mfma_f32_4x4x1f32
;
struct
intrin_mfma_f32_4x4x1f32
;
template
<
>
template
<
index_t
COffset
>
struct
intrin_mfma_f32_4x4x1f32
<
4
,
64
>
struct
intrin_mfma_f32_4x4x1f32
<
4
,
64
,
COffset
>
{
{
__device__
static
c_vec4_1_t
::
VecType
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
>
r
un
(
const
f
loat
*
reg_a
,
const
f
loat
*
reg_b
,
c_vec4_1_t
::
VecType
reg_c
)
__device__
static
void
R
un
(
const
F
loat
A
&
reg_a
,
const
F
loat
B
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c
.
s
.
x
=
llvm_intrin_amdgcn_mfma_f32_4x4x1f32
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
x
,
4
,
0
,
0
);
reg_c
(
Number
<
COffset
>
{}).
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
return
reg_c
;
llvm_intrin_amdgcn_mfma_f32_4x4x1f32
(
reg_a
,
reg_b
,
reg_c
[
Number
<
COffset
>
{}].
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
4
,
0
,
0
);
}
}
};
};
template
<
>
template
<
index_t
COffset
>
struct
intrin_mfma_f32_4x4x1f32
<
8
,
64
>
struct
intrin_mfma_f32_4x4x1f32
<
8
,
64
,
COffset
>
{
{
__device__
static
c_vec4_2_t
::
VecType
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
>
r
un
(
const
f
loat
*
reg_a
,
const
f
loat
*
reg_b
,
c_vec4_2_t
::
VecType
reg_c
)
__device__
static
void
R
un
(
const
F
loat
A
&
reg_a
,
const
F
loat
B
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c
.
s
.
x
=
llvm_intrin_amdgcn_mfma_f32_4x4x1f32
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
x
,
4
,
0
,
0
);
reg_c
(
Number
<
COffset
>
{}).
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
reg_c
.
s
.
y
=
llvm_intrin_amdgcn_mfma_f32_4x4x1f32
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
y
,
4
,
1
,
0
);
llvm_intrin_amdgcn_mfma_f32_4x4x1f32
(
return
reg_c
;
reg_a
,
reg_b
,
reg_c
[
Number
<
COffset
>
{}].
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
4
,
0
,
0
);
reg_c
(
Number
<
COffset
+
1
>
{}).
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
llvm_intrin_amdgcn_mfma_f32_4x4x1f32
(
reg_a
,
reg_b
,
reg_c
[
Number
<
COffset
+
1
>
{}].
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
4
,
1
,
0
);
}
}
};
};
...
...
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
View file @
5c27dcd5
...
@@ -49,6 +49,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
...
@@ -49,6 +49,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
wei_k_c_y_x_device_buf
.
ToDevice
(
wei_k_c_y_x
.
mData
.
data
());
wei_k_c_y_x_device_buf
.
ToDevice
(
wei_k_c_y_x
.
mData
.
data
());
out_n_k_ho_wo_device_buf
.
ToDevice
(
out_n_k_ho_wo
.
mData
.
data
());
out_n_k_ho_wo_device_buf
.
ToDevice
(
out_n_k_ho_wo
.
mData
.
data
());
static_assert
(
1
==
InWeiVectorSize
,
"support InWeiVectorSize == 1 only!"
);
#if 1
#if 1
// run-time variables
// run-time variables
const
auto
in_n_c_hi_wi_desc
=
const
auto
in_n_c_hi_wi_desc
=
...
@@ -108,12 +110,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
...
@@ -108,12 +110,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
16
;
constexpr
index_t
GemmKPerBlock
=
16
;
constexpr
index_t
GemmMPerWave
=
32
;
constexpr
index_t
GemmMPerWave
=
8
;
constexpr
index_t
GemmNPerWave
=
32
;
constexpr
index_t
GemmNPerWave
=
64
;
constexpr
index_t
GemmKPerWave
=
4
;
constexpr
index_t
GemmKPerWave
=
4
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
MRepeat
=
8
;
constexpr
index_t
NRepeat
=
2
;
constexpr
index_t
NRepeat
=
1
;
using
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
=
Sequence
<
4
,
2
>
;
using
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
=
Sequence
<
4
,
2
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
=
Sequence
<
4
,
64
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
=
Sequence
<
4
,
64
>
;
...
@@ -131,7 +133,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
...
@@ -131,7 +133,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
#endif
#endif
const
auto
descs
=
const
auto
descs
=
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad
<
GemmMPerBlock
,
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad
<
TInWei
,
GemmMPerBlock
,
GemmNPerBlock
,
GemmNPerBlock
,
GemmMPerWave
,
GemmMPerWave
,
GemmNPerWave
,
GemmNPerWave
,
...
@@ -148,7 +151,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
...
@@ -148,7 +151,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
{
{
float
ave_time
=
launch_kernel_dynamic_gemm_xdlops_v1
<
float
ave_time
=
launch_kernel_dynamic_gemm_xdlops_v1
<
BlockSize
,
BlockSize
,
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
,
TInWei
,
TAcc
,
TAcc
,
TOut
,
TOut
,
InMemoryDataOperation
::
Set
,
InMemoryDataOperation
::
Set
,
...
@@ -188,10 +191,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
...
@@ -188,10 +191,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
decltype
(
descs
[
I5
]),
decltype
(
descs
[
I5
]),
decltype
(
descs
[
I6
]),
decltype
(
descs
[
I6
]),
decltype
(
descs
[
I7
]),
decltype
(
descs
[
I7
]),
decltype
(
descs
[
I8
])
>
(
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
decltype
(
descs
[
I8
])
>
(
static_cast
<
TInWei
*>
(
wei_k_c_y_x_device_buf
.
GetDeviceBuffer
()),
wei_k_c_y_x_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
in_n_c_hi_wi_device_buf
.
GetDeviceBuffer
()),
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
in_n_c_hi_wi_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_k_ho_wo_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_k_ho_wo_device_buf
.
GetDeviceBuffer
()),
descs
[
I0
],
descs
[
I0
],
descs
[
I1
],
descs
[
I1
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment