Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
90ec6a19
"...composable_kernel.git" did not exist on "f6cb5b846d1eff1d1e35ab58273becfd40bd0831"
Commit
90ec6a19
authored
May 19, 2021
by
Jing Zhang
Browse files
added 128x128 wavegemm
parent
1d48b521
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
429 additions
and
207 deletions
+429
-207
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
...kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
+45
-45
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
+117
-143
composable_kernel/include/utility/amd_xdlops.hpp
composable_kernel/include/utility/amd_xdlops.hpp
+8
-6
composable_kernel/include/utility/float_type.amd.hpp.in
composable_kernel/include/utility/float_type.amd.hpp.in
+246
-0
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
...tion_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
+13
-13
No files found.
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
View file @
90ec6a19
...
@@ -160,11 +160,11 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
...
@@ -160,11 +160,11 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
xdlops_gemm
.
template
Run
2
<
decltype
(
a_thread_desc_
),
xdlops_gemm
.
template
Run
<
decltype
(
a_thread_desc_
),
decltype
(
b_thread_desc_
),
decltype
(
b_thread_desc_
),
decltype
(
c_thread_desc_
),
decltype
(
c_thread_desc_
),
m0
,
m0
,
n0
>(
a_thread_buf
,
b_thread_buf
,
c_thread_buf
);
n0
>(
a_thread_buf
,
b_thread_buf
,
c_thread_buf
);
});
});
});
});
});
});
...
@@ -372,18 +372,18 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
...
@@ -372,18 +372,18 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
a_thread_buf
);
a_thread_buf
);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
// C_sub_00 += transpose(A_sub_0) * B_sub_0
xdlops_gemm
.
template
Run
2
<
decltype
(
a_thread_desc_
),
xdlops_gemm
.
template
Run
<
decltype
(
a_thread_desc_
),
decltype
(
b_thread_desc_
),
decltype
(
b_thread_desc_
),
decltype
(
c_thread_desc_
),
decltype
(
c_thread_desc_
),
0
,
0
,
0
>(
a_thread_buf
,
b_thread_buf
,
c_thread_buf
);
0
>(
a_thread_buf
,
b_thread_buf
,
c_thread_buf
);
// C_sub_01 += transpose(A_sub_0) * B_sub_1
// C_sub_01 += transpose(A_sub_0) * B_sub_1
xdlops_gemm
.
template
Run
2
<
decltype
(
a_thread_desc_
),
xdlops_gemm
.
template
Run
<
decltype
(
a_thread_desc_
),
decltype
(
b_thread_desc_
),
decltype
(
b_thread_desc_
),
decltype
(
c_thread_desc_
),
decltype
(
c_thread_desc_
),
0
,
0
,
1
>(
a_thread_buf
,
b_thread_buf
,
c_thread_buf
);
1
>(
a_thread_buf
,
b_thread_buf
,
c_thread_buf
);
static_for
<
KPerWave
,
KPerBlock
,
KPerWave
>
{}([
&
](
auto
k
)
{
static_for
<
KPerWave
,
KPerBlock
,
KPerWave
>
{}([
&
](
auto
k
)
{
// read A_sub_0
// read A_sub_0
...
@@ -395,11 +395,11 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
...
@@ -395,11 +395,11 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
a_thread_buf
);
a_thread_buf
);
// C_sub_10 += transpose(A_sub_1) * B_sub_0
// C_sub_10 += transpose(A_sub_1) * B_sub_0
xdlops_gemm
.
template
Run
2
<
decltype
(
a_thread_desc_
),
xdlops_gemm
.
template
Run
<
decltype
(
a_thread_desc_
),
decltype
(
b_thread_desc_
),
decltype
(
b_thread_desc_
),
decltype
(
c_thread_desc_
),
decltype
(
c_thread_desc_
),
1
,
1
,
0
>(
a_thread_buf
,
b_thread_buf
,
c_thread_buf
);
0
>(
a_thread_buf
,
b_thread_buf
,
c_thread_buf
);
// read B_sub_0
// read B_sub_0
b_thread_copy_
.
Run
(
BBlockDesc
{},
b_thread_copy_
.
Run
(
BBlockDesc
{},
...
@@ -410,11 +410,11 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
...
@@ -410,11 +410,11 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
b_thread_buf
);
b_thread_buf
);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
// C_sub_11 += transpose(A_sub_1) * B_sub_1
xdlops_gemm
.
template
Run
2
<
decltype
(
a_thread_desc_
),
xdlops_gemm
.
template
Run
<
decltype
(
a_thread_desc_
),
decltype
(
b_thread_desc_
),
decltype
(
b_thread_desc_
),
decltype
(
c_thread_desc_
),
decltype
(
c_thread_desc_
),
1
,
1
,
1
>(
a_thread_buf
,
b_thread_buf
,
c_thread_buf
);
1
>(
a_thread_buf
,
b_thread_buf
,
c_thread_buf
);
// read B_sub_1
// read B_sub_1
b_thread_copy_
.
Run
(
BBlockDesc
{},
b_thread_copy_
.
Run
(
BBlockDesc
{},
...
@@ -433,33 +433,33 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
...
@@ -433,33 +433,33 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
a_thread_buf
);
a_thread_buf
);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
// C_sub_00 += transpose(A_sub_0) * B_sub_0
xdlops_gemm
.
template
Run
2
<
decltype
(
a_thread_desc_
),
xdlops_gemm
.
template
Run
<
decltype
(
a_thread_desc_
),
decltype
(
b_thread_desc_
),
decltype
(
b_thread_desc_
),
decltype
(
c_thread_desc_
),
decltype
(
c_thread_desc_
),
0
,
0
,
0
>(
a_thread_buf
,
b_thread_buf
,
c_thread_buf
);
0
>(
a_thread_buf
,
b_thread_buf
,
c_thread_buf
);
// C_sub_01 += transpose(A_sub_0) * B_sub_1
// C_sub_01 += transpose(A_sub_0) * B_sub_1
xdlops_gemm
.
template
Run
2
<
decltype
(
a_thread_desc_
),
xdlops_gemm
.
template
Run
<
decltype
(
a_thread_desc_
),
decltype
(
b_thread_desc_
),
decltype
(
b_thread_desc_
),
decltype
(
c_thread_desc_
),
decltype
(
c_thread_desc_
),
0
,
0
,
1
>(
a_thread_buf
,
b_thread_buf
,
c_thread_buf
);
1
>(
a_thread_buf
,
b_thread_buf
,
c_thread_buf
);
});
});
// C_sub_10 += transpose(A_sub_1) * B_sub_0
// C_sub_10 += transpose(A_sub_1) * B_sub_0
xdlops_gemm
.
template
Run
2
<
decltype
(
a_thread_desc_
),
xdlops_gemm
.
template
Run
<
decltype
(
a_thread_desc_
),
decltype
(
b_thread_desc_
),
decltype
(
b_thread_desc_
),
decltype
(
c_thread_desc_
),
decltype
(
c_thread_desc_
),
1
,
1
,
0
>(
a_thread_buf
,
b_thread_buf
,
c_thread_buf
);
0
>(
a_thread_buf
,
b_thread_buf
,
c_thread_buf
);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
// C_sub_11 += transpose(A_sub_1) * B_sub_1
xdlops_gemm
.
template
Run
2
<
decltype
(
a_thread_desc_
),
xdlops_gemm
.
template
Run
<
decltype
(
a_thread_desc_
),
decltype
(
b_thread_desc_
),
decltype
(
b_thread_desc_
),
decltype
(
c_thread_desc_
),
decltype
(
c_thread_desc_
),
1
,
1
,
1
>(
a_thread_buf
,
b_thread_buf
,
c_thread_buf
);
1
>(
a_thread_buf
,
b_thread_buf
,
c_thread_buf
);
}
}
private:
private:
...
...
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
View file @
90ec6a19
...
@@ -518,7 +518,7 @@ template <mfma_instr instr,
...
@@ -518,7 +518,7 @@ template <mfma_instr instr,
index_t
NPerXdlops_
,
index_t
NPerXdlops_
,
index_t
MRepeats_
,
index_t
MRepeats_
,
index_t
NRepeats_
,
index_t
NRepeats_
,
class
OutputVec
Type_
>
class
C
Type_
>
struct
xdlops_info
struct
xdlops_info
{
{
static
constexpr
auto
mfma_type
=
mfma_info
<
instr
>
{};
static
constexpr
auto
mfma_type
=
mfma_info
<
instr
>
{};
...
@@ -540,196 +540,74 @@ struct xdlops_info
...
@@ -540,196 +540,74 @@ struct xdlops_info
return
mfma_type
.
k_base
*
(
IsKReduction
()
?
mfma_type
.
num_input_blks
:
1
);
return
mfma_type
.
k_base
*
(
IsKReduction
()
?
mfma_type
.
num_input_blks
:
1
);
}
}
static
constexpr
auto
OutputVecType
=
OutputVecType_
{};
static
constexpr
index_t
GetNumCRegs
()
{
return
MPerXdlops
*
NPerXdlops
/
mfma_type
.
wave_size
;
}
static
constexpr
auto
GetCType
()
{
return
CType_
{};
}
};
};
template
<
class
data
_type
,
index_t
MPerWave
,
index_t
NPerWave
,
index_t
KPerWave
>
template
<
class
base
_type
,
index_t
MPerWave
,
index_t
NPerWave
,
index_t
KPerWave
>
struct
XdlopsGemm
struct
XdlopsGemm
{
{
struct
MatrixIndex
template
<
class
base_type_
=
base_type
,
{
index_t
row
;
index_t
col
;
};
__device__
static
constexpr
index_t
GetNumBlksPerXdlops
()
{
return
(
MPerXdlops
*
NPerXdlops
)
/
(
mfma_type
.
m
*
mfma_type
.
n
);
}
__host__
__device__
constexpr
XdlopsGemm
()
{
static_assert
(
NPerXdlops
==
4
||
NPerXdlops
==
8
||
NPerXdlops
==
16
||
NPerXdlops
==
32
||
NPerXdlops
==
64
,
"Only support GemmNPerXdlops == 4, 8, 16, 32 or 64 for xdlops"
);
static_assert
(
MPerXdlops
==
4
||
MPerXdlops
==
8
||
MPerXdlops
==
16
||
MPerXdlops
==
32
||
MPerXdlops
==
64
,
"Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops"
);
static_assert
(
mfma_type
.
num_threads_blk
==
mfma_type
.
n
,
"n != num_threads_blk"
);
static_assert
(
mfma_type
.
num_regs_blk
*
mfma_type
.
num_input_blks
==
mfma_type
.
m
,
"m != num_input_blks * num_regs_blk"
);
static_assert
(
mfma_type
.
num_output_blks
==
mfma_type
.
num_input_blks
||
mfma_type
.
num_output_blks
==
1
,
"incorrect num_output_blks"
);
static_assert
(
mfma_type
.
num_regs_blk
*
mfma_type
.
wave_size
==
mfma_type
.
m
*
mfma_type
.
n
,
"num_regs_blk incorrect"
);
static_assert
(
mfma_type
.
k
%
mfma_type
.
k_base
==
0
,
"k % kbase != 0!"
);
}
__device__
static
constexpr
index_t
GetRegSizePerXdlops
()
{
return
MPerXdlops
*
NPerXdlops
/
mfma_type
.
wave_size
;
}
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
Run
(
const
FloatA
&
p_a_wave
,
const
FloatB
&
p_b_wave
,
FloatC
&
p_c_thread
)
const
{
static_assert
(
is_same
<
data_type
,
float
>::
value
||
is_same
<
data_type
,
half_t
>::
value
||
is_same
<
data_type
,
ushort
>::
value
,
"base data_type must be float, half, ushort!"
);
static_assert
(
KPerWave
%
KPerXdlops
==
0
,
"KPerWave cannot be divided by KPerXdlops"
);
static_for
<
0
,
KPerWave
,
KPerXdlops
>
{}([
&
](
auto
k_i
)
{
mfma_type
.
template
run
<
MPerXdlops
,
NPerXdlops
>(
p_a_wave
[
Number
<
k_i
>
{}],
p_b_wave
[
Number
<
k_i
>
{}],
p_c_thread
);
});
}
template
<
class
ADesc
,
class
BDesc
,
class
CDesc
,
index_t
m0
,
index_t
n0
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
Run2
(
const
FloatA
&
p_a_wave
,
const
FloatB
&
p_b_wave
,
FloatC
&
p_c_thread
)
const
{
static_assert
(
is_same
<
data_type
,
float
>::
value
||
is_same
<
data_type
,
half_t
>::
value
||
is_same
<
data_type
,
ushort
>::
value
,
"base data_type must be float, half, ushort!"
);
static_assert
(
KPerWave
%
KPerXdlops
==
0
,
"KPerWave cannot be divided by KPerXdlops"
);
static_for
<
0
,
KPerWave
,
KPerXdlops
>
{}([
&
](
auto
k
)
{
constexpr
index_t
a_offset
=
ADesc
{}.
CalculateOffset
(
make_multi_index
(
k
,
m0
,
0
));
constexpr
index_t
b_offset
=
BDesc
{}.
CalculateOffset
(
make_multi_index
(
k
,
n0
,
0
));
constexpr
index_t
c_offset
=
CDesc
{}.
CalculateOffset
(
make_multi_index
(
m0
,
n0
));
mfma_type
.
template
run
<
MPerXdlops
,
NPerXdlops
>(
p_a_wave
[
Number
<
a_offset
>
{}],
p_b_wave
[
Number
<
b_offset
>
{}],
p_c_thread
.
template
AsType
<
float32_t
>());
});
}
__device__
static
MatrixIndex
GetBeginOfThreadBlk
(
index_t
i
)
{
const
index_t
xdlops_i
=
i
/
GetNumBlksPerXdlops
();
const
index_t
j
=
i
%
GetNumBlksPerXdlops
();
const
index_t
m_i
=
xdlops_i
/
NRepeats
;
const
index_t
n_i
=
xdlops_i
%
NRepeats
;
const
index_t
laneId
=
get_thread_local_1d_id
()
%
mfma_type
.
wave_size
;
const
index_t
blk_id
=
laneId
/
mfma_type
.
num_threads_blk
;
const
index_t
blk_td
=
laneId
%
mfma_type
.
num_threads_blk
;
index_t
col_blk
=
j
%
mfma_type
.
num_output_blks
;
index_t
row_blk
=
j
/
mfma_type
.
num_output_blks
;
static_if
<!
IsABroadcast
>
{}([
&
](
auto
)
{
col_blk
=
j
/
mfma_type
.
num_output_blks
;
row_blk
=
j
%
mfma_type
.
num_output_blks
;
});
index_t
col
=
col_blk
*
mfma_type
.
n
+
blk_td
+
n_i
*
NPerXdlops
;
index_t
row
=
row_blk
*
mfma_type
.
m
+
blk_id
*
mfma_type
.
group_size
+
m_i
*
MPerXdlops
;
return
MatrixIndex
{
row
,
col
};
}
__device__
void
SetZeroXdlopsRegs
()
const
{}
template
<
class
FloatC
>
__device__
void
ReadXdlopsRegs
(
FloatC
*
const
__restrict__
)
const
{
}
template
<
class
data_type_
=
data_type
,
index_t
MPerWave_
=
MPerWave
,
index_t
MPerWave_
=
MPerWave
,
index_t
NPerWave_
=
NPerWave
>
index_t
NPerWave_
=
NPerWave
>
static
constexpr
auto
GetXdlopsInfo
();
static
constexpr
auto
GetXdlopsInfo
();
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
float
,
128
,
64
>
()
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_32x32x1xf32
,
64
,
64
,
2
,
1
,
c_vec32_4_t
>
{};
}
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
float
,
64
,
128
>
()
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_32x32x1xf32
,
64
,
64
,
1
,
2
,
c_vec32_4_t
>
{};
}
template
<
>
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
float
,
64
,
64
>
()
static
constexpr
auto
GetXdlopsInfo
<
float
,
64
,
64
>
()
{
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_32x32x1xf32
,
64
,
64
,
1
,
1
,
c_vec32_2
_t
>
{};
return
xdlops_info
<
mfma_instr
::
mfma_f32_32x32x1xf32
,
64
,
64
,
1
,
1
,
float64
_t
>
{};
}
}
template
<
>
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
float
,
64
,
32
>
()
static
constexpr
auto
GetXdlopsInfo
<
float
,
64
,
32
>
()
{
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_32x32x1xf32
,
64
,
32
,
1
,
1
,
c_vec32_1
_t
>
{};
return
xdlops_info
<
mfma_instr
::
mfma_f32_32x32x1xf32
,
64
,
32
,
1
,
1
,
float32
_t
>
{};
}
}
template
<
>
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
float
,
32
,
64
>
()
static
constexpr
auto
GetXdlopsInfo
<
float
,
32
,
64
>
()
{
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_32x32x1xf32
,
32
,
64
,
1
,
1
,
c_vec32_1
_t
>
{};
return
xdlops_info
<
mfma_instr
::
mfma_f32_32x32x1xf32
,
32
,
64
,
1
,
1
,
float32
_t
>
{};
}
}
template
<
>
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
float
,
64
,
16
>
()
static
constexpr
auto
GetXdlopsInfo
<
float
,
64
,
16
>
()
{
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_16x16x1xf32
,
64
,
16
,
1
,
1
,
c_vec16_1
_t
>
{};
return
xdlops_info
<
mfma_instr
::
mfma_f32_16x16x1xf32
,
64
,
16
,
1
,
1
,
float16
_t
>
{};
}
}
template
<
>
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
float
,
16
,
64
>
()
static
constexpr
auto
GetXdlopsInfo
<
float
,
16
,
64
>
()
{
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_16x16x1xf32
,
16
,
64
,
1
,
1
,
c_vec16_1
_t
>
{};
return
xdlops_info
<
mfma_instr
::
mfma_f32_16x16x1xf32
,
16
,
64
,
1
,
1
,
float16
_t
>
{};
}
}
template
<
>
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
float
,
8
,
64
>
()
static
constexpr
auto
GetXdlopsInfo
<
float
,
8
,
64
>
()
{
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_4x4x1xf32
,
8
,
64
,
1
,
1
,
c_vec4_2
_t
>
{};
return
xdlops_info
<
mfma_instr
::
mfma_f32_4x4x1xf32
,
8
,
64
,
1
,
1
,
float8
_t
>
{};
}
}
template
<
>
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
float
,
4
,
64
>
()
static
constexpr
auto
GetXdlopsInfo
<
float
,
4
,
64
>
()
{
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_4x4x1xf32
,
4
,
64
,
1
,
1
,
c_vec4_1
_t
>
{};
return
xdlops_info
<
mfma_instr
::
mfma_f32_4x4x1xf32
,
4
,
64
,
1
,
1
,
float4
_t
>
{};
}
}
template
<
>
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
float
,
32
,
32
>
()
static
constexpr
auto
GetXdlopsInfo
<
float
,
32
,
32
>
()
{
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_32x32x2xf32
,
32
,
32
,
1
,
1
,
c_vec16_1
_t
>
{};
return
xdlops_info
<
mfma_instr
::
mfma_f32_32x32x2xf32
,
32
,
32
,
1
,
1
,
float16
_t
>
{};
}
}
template
<
>
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
float
,
16
,
16
>
()
static
constexpr
auto
GetXdlopsInfo
<
float
,
16
,
16
>
()
{
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_16x16x4xf32
,
16
,
16
,
1
,
1
,
c_vec4_1
_t
>
{};
return
xdlops_info
<
mfma_instr
::
mfma_f32_16x16x4xf32
,
16
,
16
,
1
,
1
,
float4
_t
>
{};
}
}
#if 0
template <>
template <>
static constexpr auto GetXdlopsInfo<half_t, 128, 64>()
static constexpr auto GetXdlopsInfo<half_t, 128, 64>()
{
{
...
@@ -861,6 +739,107 @@ struct XdlopsGemm
...
@@ -861,6 +739,107 @@ struct XdlopsGemm
{
{
return xdlops_info<mfma_instr::mfma_f32_16x16x8bf16, 16, 16, 1, 1, c_vec4_1_t>{};
return xdlops_info<mfma_instr::mfma_f32_16x16x8bf16, 16, 16, 1, 1, c_vec4_1_t>{};
}
}
#endif
struct
MatrixIndex
{
index_t
row
;
index_t
col
;
};
__device__
static
constexpr
index_t
GetNumBlksPerXdlops
()
{
return
(
MPerXdlops
*
NPerXdlops
)
/
(
mfma_type
.
m
*
mfma_type
.
n
);
}
__host__
__device__
constexpr
XdlopsGemm
()
{
static_assert
(
NPerXdlops
==
4
||
NPerXdlops
==
8
||
NPerXdlops
==
16
||
NPerXdlops
==
32
||
NPerXdlops
==
64
,
"Only support GemmNPerXdlops == 4, 8, 16, 32 or 64 for xdlops"
);
static_assert
(
MPerXdlops
==
4
||
MPerXdlops
==
8
||
MPerXdlops
==
16
||
MPerXdlops
==
32
||
MPerXdlops
==
64
,
"Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops"
);
static_assert
(
mfma_type
.
num_threads_blk
==
mfma_type
.
n
,
"n != num_threads_blk"
);
static_assert
(
mfma_type
.
num_regs_blk
*
mfma_type
.
num_input_blks
==
mfma_type
.
m
,
"m != num_input_blks * num_regs_blk"
);
static_assert
(
mfma_type
.
num_output_blks
==
mfma_type
.
num_input_blks
||
mfma_type
.
num_output_blks
==
1
,
"incorrect num_output_blks"
);
static_assert
(
mfma_type
.
num_regs_blk
*
mfma_type
.
wave_size
==
mfma_type
.
m
*
mfma_type
.
n
,
"num_regs_blk incorrect"
);
static_assert
(
mfma_type
.
k
%
mfma_type
.
k_base
==
0
,
"k % kbase != 0!"
);
}
__device__
static
constexpr
index_t
GetRegSizePerXdlops
()
{
return
MPerXdlops
*
NPerXdlops
/
mfma_type
.
wave_size
;
}
template
<
class
ADesc
,
class
BDesc
,
class
CDesc
,
index_t
m0
,
index_t
n0
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
Run
(
const
FloatA
&
p_a_wave
,
const
FloatB
&
p_b_wave
,
FloatC
&
p_c_thread
)
const
{
static_assert
(
is_same
<
base_type
,
float
>::
value
||
is_same
<
base_type
,
half_t
>::
value
||
is_same
<
base_type
,
ushort
>::
value
,
"base base_type must be float, half, ushort!"
);
static_assert
(
KPerWave
%
KPerXdlops
==
0
,
"KPerWave cannot be divided by KPerXdlops"
);
static_for
<
0
,
KPerWave
,
KPerXdlops
>
{}([
&
](
auto
k
)
{
constexpr
index_t
a_offset
=
ADesc
{}.
CalculateOffset
(
make_multi_index
(
k
,
m0
,
0
));
constexpr
index_t
b_offset
=
BDesc
{}.
CalculateOffset
(
make_multi_index
(
k
,
n0
,
0
));
constexpr
index_t
c_offset
=
CDesc
{}.
CalculateOffset
(
make_multi_index
(
m0
,
n0
));
vector_type
<
base_type
,
GetXdlopsInfo
().
GetNumCRegs
()
>
t
;
using
c_type
=
decltype
(
GetXdlopsInfo
().
GetCType
());
t
.
template
AsType
<
c_type
>()(
Number
<
0
>
{})
=
p_c_thread
.
template
AsType
<
c_type
>()[
Number
<
c_offset
>
{}];
mfma_type
.
template
run
<
MPerXdlops
,
NPerXdlops
>(
p_a_wave
[
Number
<
a_offset
>
{}],
p_b_wave
[
Number
<
b_offset
>
{}],
t
);
p_c_thread
.
template
AsType
<
c_type
>()(
Number
<
c_offset
>
{})
=
t
.
template
AsType
<
c_type
>()[
Number
<
0
>
{}];
});
}
__device__
static
MatrixIndex
GetBeginOfThreadBlk
(
index_t
i
)
{
const
index_t
xdlops_i
=
i
/
GetNumBlksPerXdlops
();
const
index_t
j
=
i
%
GetNumBlksPerXdlops
();
const
index_t
m_i
=
xdlops_i
/
NRepeats
;
const
index_t
n_i
=
xdlops_i
%
NRepeats
;
const
index_t
laneId
=
get_thread_local_1d_id
()
%
mfma_type
.
wave_size
;
const
index_t
blk_id
=
laneId
/
mfma_type
.
num_threads_blk
;
const
index_t
blk_td
=
laneId
%
mfma_type
.
num_threads_blk
;
index_t
col_blk
=
j
%
mfma_type
.
num_output_blks
;
index_t
row_blk
=
j
/
mfma_type
.
num_output_blks
;
static_if
<!
IsABroadcast
>
{}([
&
](
auto
)
{
col_blk
=
j
/
mfma_type
.
num_output_blks
;
row_blk
=
j
%
mfma_type
.
num_output_blks
;
});
index_t
col
=
col_blk
*
mfma_type
.
n
+
blk_td
+
n_i
*
NPerXdlops
;
index_t
row
=
row_blk
*
mfma_type
.
m
+
blk_id
*
mfma_type
.
group_size
+
m_i
*
MPerXdlops
;
return
MatrixIndex
{
row
,
col
};
}
static
constexpr
index_t
MRepeats
=
GetXdlopsInfo
().
MRepeats
;
static
constexpr
index_t
MRepeats
=
GetXdlopsInfo
().
MRepeats
;
static
constexpr
index_t
NRepeats
=
GetXdlopsInfo
().
NRepeats
;
static
constexpr
index_t
NRepeats
=
GetXdlopsInfo
().
NRepeats
;
...
@@ -896,11 +875,6 @@ struct XdlopsGemm
...
@@ -896,11 +875,6 @@ struct XdlopsGemm
{
{
return
GetNumBlksPerXdlops
()
*
MRepeats
*
NRepeats
;
return
GetNumBlksPerXdlops
()
*
MRepeats
*
NRepeats
;
}
}
__device__
static
constexpr
auto
CreateOutputVecZero
()
{
return
GetXdlopsInfo
().
OutputVecType
.
CreateVecZero
();
}
};
};
__host__
__device__
static
constexpr
auto
GetOutputLayout
()
{
return
OutputLayout
{};
}
__host__
__device__
static
constexpr
auto
GetOutputLayout
()
{
return
OutputLayout
{};
}
...
...
composable_kernel/include/utility/amd_xdlops.hpp
View file @
90ec6a19
...
@@ -243,10 +243,10 @@ struct intrin_mfma_f32_32x32x1f32<64, 64>
...
@@ -243,10 +243,10 @@ struct intrin_mfma_f32_32x32x1f32<64, 64>
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
>
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
static
void
Run
(
const
FloatA
&
reg_a
,
const
FloatB
&
reg_b
,
FloatC
&
reg_c
)
__device__
static
void
Run
(
const
FloatA
&
reg_a
,
const
FloatB
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c
(
Number
<
0
>
{})
=
reg_c
.
template
AsType
<
float32_t
>()(
Number
<
0
>
{})
=
llvm_intrin_amdgcn_mfma_f32_32x32x1f32
(
llvm_intrin_amdgcn_mfma_f32_32x32x1f32
(
reg_a
,
reg_b
,
reg_c
[
Number
<
0
>
{}],
1
,
0
,
0
);
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float32_t
>()
[
Number
<
0
>
{}],
1
,
0
,
0
);
reg_c
(
Number
<
1
>
{})
=
reg_c
.
template
AsType
<
float32_t
>()(
Number
<
1
>
{})
=
llvm_intrin_amdgcn_mfma_f32_32x32x1f32
(
llvm_intrin_amdgcn_mfma_f32_32x32x1f32
(
reg_a
,
reg_b
,
reg_c
[
Number
<
1
>
{}],
1
,
1
,
0
);
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float32_t
>()
[
Number
<
1
>
{}],
1
,
1
,
0
);
}
}
};
};
...
@@ -278,9 +278,11 @@ struct intrin_mfma_f32_32x32x2f32;
...
@@ -278,9 +278,11 @@ struct intrin_mfma_f32_32x32x2f32;
template
<
>
template
<
>
struct
intrin_mfma_f32_32x32x2f32
<
32
,
32
>
struct
intrin_mfma_f32_32x32x2f32
<
32
,
32
>
{
{
__device__
static
void
Run
(
const
float
&
reg_a
,
const
float
&
reg_b
,
float16_t
&
reg_c
)
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
static
void
Run
(
const
FloatA
&
reg_a
,
const
FloatB
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c
=
llvm_intrin_amdgcn_mfma_f32_32x32x2f32
(
reg_a
,
reg_b
,
reg_c
,
0
,
0
,
0
);
reg_c
.
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
llvm_intrin_amdgcn_mfma_f32_32x32x2f32
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
}
}
};
};
...
...
composable_kernel/include/utility/float_type.amd.hpp.in
View file @
90ec6a19
...
@@ -618,6 +618,252 @@ struct vector_type<T, 64>
...
@@ -618,6 +618,252 @@ struct vector_type<T, 64>
}
}
};
};
// Specialization of vector_type for 128 elements of T.
//
// The storage is a single union (`data_`) that overlays the full 128-wide
// extended vector with StaticallyIndexedArray views made of narrower
// sub-vectors (1, 2, 4, ..., 128 elements each). AsType<X>() selects the view
// whose element type is X.
//
// NOTE(review): reading through a view other than the one last written relies
// on the compiler's treatment of union members — confirm this is intended for
// the target toolchain.
template <typename T>
struct vector_type<T, 128>
{
    // Element types of the available views, from scalar up to the full vector.
    using d1_t = T;
    typedef T d2_t __attribute__((ext_vector_type(2)));
    typedef T d4_t __attribute__((ext_vector_type(4)));
    typedef T d8_t __attribute__((ext_vector_type(8)));
    typedef T d16_t __attribute__((ext_vector_type(16)));
    typedef T d32_t __attribute__((ext_vector_type(32)));
    typedef T d64_t __attribute__((ext_vector_type(64)));
    typedef T d128_t __attribute__((ext_vector_type(128)));

    // The full-width vector type exposed to users of vector_type.
    using type = d128_t;

    // Member order matters: the constructors below initialize the first
    // member (d128_).
    union
    {
        d128_t d128_;
        StaticallyIndexedArray<d1_t, 128> d1x128_;
        StaticallyIndexedArray<d2_t, 64> d2x64_;
        StaticallyIndexedArray<d4_t, 32> d4x32_;
        StaticallyIndexedArray<d8_t, 16> d8x16_;
        StaticallyIndexedArray<d16_t, 8> d16x8_;
        StaticallyIndexedArray<d32_t, 4> d32x4_;
        StaticallyIndexedArray<d64_t, 2> d64x2_;
        StaticallyIndexedArray<d128_t, 1> d128x1_;
    } data_;

    // Default construction initializes the storage from `type{0}`.
    __host__ __device__ constexpr vector_type() : data_{type{0}} {}

    // Construct from an existing full-width vector value.
    __host__ __device__ constexpr vector_type(type vec) : data_{vec} {}

    // Read-only view of the storage as an array whose elements have type X.
    // X must be one of d1_t ... d128_t; anything else fails the static_assert.
    template <typename X>
    __host__ __device__ constexpr const auto& AsType() const
    {
        constexpr bool is_known_view =
            is_same<X, d128_t>::value || is_same<X, d64_t>::value ||
            is_same<X, d32_t>::value || is_same<X, d16_t>::value ||
            is_same<X, d8_t>::value || is_same<X, d4_t>::value ||
            is_same<X, d2_t>::value || is_same<X, d1_t>::value;
        static_assert(is_known_view, "wrong!");

        // Exactly one branch survives; the element types are all distinct.
        if constexpr(is_same<X, d128_t>::value)
        {
            return data_.d128x1_;
        }
        else if constexpr(is_same<X, d64_t>::value)
        {
            return data_.d64x2_;
        }
        else if constexpr(is_same<X, d32_t>::value)
        {
            return data_.d32x4_;
        }
        else if constexpr(is_same<X, d16_t>::value)
        {
            return data_.d16x8_;
        }
        else if constexpr(is_same<X, d8_t>::value)
        {
            return data_.d8x16_;
        }
        else if constexpr(is_same<X, d4_t>::value)
        {
            return data_.d4x32_;
        }
        else if constexpr(is_same<X, d2_t>::value)
        {
            return data_.d2x64_;
        }
        else if constexpr(is_same<X, d1_t>::value)
        {
            return data_.d1x128_;
        }
    }

    // Mutable counterpart of the const AsType<X>() above; same view selection.
    template <typename X>
    __host__ __device__ constexpr auto& AsType()
    {
        constexpr bool is_known_view =
            is_same<X, d128_t>::value || is_same<X, d64_t>::value ||
            is_same<X, d32_t>::value || is_same<X, d16_t>::value ||
            is_same<X, d8_t>::value || is_same<X, d4_t>::value ||
            is_same<X, d2_t>::value || is_same<X, d1_t>::value;
        static_assert(is_known_view, "wrong!");

        if constexpr(is_same<X, d128_t>::value)
        {
            return data_.d128x1_;
        }
        else if constexpr(is_same<X, d64_t>::value)
        {
            return data_.d64x2_;
        }
        else if constexpr(is_same<X, d32_t>::value)
        {
            return data_.d32x4_;
        }
        else if constexpr(is_same<X, d16_t>::value)
        {
            return data_.d16x8_;
        }
        else if constexpr(is_same<X, d8_t>::value)
        {
            return data_.d8x16_;
        }
        else if constexpr(is_same<X, d4_t>::value)
        {
            return data_.d4x32_;
        }
        else if constexpr(is_same<X, d2_t>::value)
        {
            return data_.d2x64_;
        }
        else if constexpr(is_same<X, d1_t>::value)
        {
            return data_.d1x128_;
        }
    }
};
// Specialization of vector_type for 256 elements of T.
//
// The storage is a single union (`data_`) that overlays the full 256-wide
// extended vector with StaticallyIndexedArray views made of narrower
// sub-vectors (1, 2, 4, ..., 256 elements each). AsType<X>() selects the view
// whose element type is X.
//
// NOTE(review): reading through a view other than the one last written relies
// on the compiler's treatment of union members — confirm this is intended for
// the target toolchain.
template <typename T>
struct vector_type<T, 256>
{
    // Element types of the available views, from scalar up to the full vector.
    using d1_t = T;
    typedef T d2_t __attribute__((ext_vector_type(2)));
    typedef T d4_t __attribute__((ext_vector_type(4)));
    typedef T d8_t __attribute__((ext_vector_type(8)));
    typedef T d16_t __attribute__((ext_vector_type(16)));
    typedef T d32_t __attribute__((ext_vector_type(32)));
    typedef T d64_t __attribute__((ext_vector_type(64)));
    typedef T d128_t __attribute__((ext_vector_type(128)));
    typedef T d256_t __attribute__((ext_vector_type(256)));

    // The full-width vector type exposed to users of vector_type.
    using type = d256_t;

    // Member order matters: the constructors below initialize the first
    // member (d256_).
    union
    {
        d256_t d256_;
        StaticallyIndexedArray<d1_t, 256> d1x256_;
        StaticallyIndexedArray<d2_t, 128> d2x128_;
        StaticallyIndexedArray<d4_t, 64> d4x64_;
        StaticallyIndexedArray<d8_t, 32> d8x32_;
        StaticallyIndexedArray<d16_t, 16> d16x16_;
        StaticallyIndexedArray<d32_t, 8> d32x8_;
        StaticallyIndexedArray<d64_t, 4> d64x4_;
        StaticallyIndexedArray<d128_t, 2> d128x2_;
        StaticallyIndexedArray<d256_t, 1> d256x1_;
    } data_;

    // Default construction initializes the storage from `type{0}`.
    __host__ __device__ constexpr vector_type() : data_{type{0}} {}

    // Construct from an existing full-width vector value.
    __host__ __device__ constexpr vector_type(type vec) : data_{vec} {}

    // Read-only view of the storage as an array whose elements have type X.
    // X must be one of d1_t ... d256_t; anything else fails the static_assert.
    template <typename X>
    __host__ __device__ constexpr const auto& AsType() const
    {
        constexpr bool is_known_view =
            is_same<X, d256_t>::value || is_same<X, d128_t>::value ||
            is_same<X, d64_t>::value || is_same<X, d32_t>::value ||
            is_same<X, d16_t>::value || is_same<X, d8_t>::value ||
            is_same<X, d4_t>::value || is_same<X, d2_t>::value ||
            is_same<X, d1_t>::value;
        static_assert(is_known_view, "wrong!");

        // Exactly one branch survives; the element types are all distinct.
        if constexpr(is_same<X, d256_t>::value)
        {
            return data_.d256x1_;
        }
        else if constexpr(is_same<X, d128_t>::value)
        {
            return data_.d128x2_;
        }
        else if constexpr(is_same<X, d64_t>::value)
        {
            return data_.d64x4_;
        }
        else if constexpr(is_same<X, d32_t>::value)
        {
            return data_.d32x8_;
        }
        else if constexpr(is_same<X, d16_t>::value)
        {
            return data_.d16x16_;
        }
        else if constexpr(is_same<X, d8_t>::value)
        {
            return data_.d8x32_;
        }
        else if constexpr(is_same<X, d4_t>::value)
        {
            return data_.d4x64_;
        }
        else if constexpr(is_same<X, d2_t>::value)
        {
            return data_.d2x128_;
        }
        else if constexpr(is_same<X, d1_t>::value)
        {
            return data_.d1x256_;
        }
    }

    // Mutable counterpart of the const AsType<X>() above; same view selection.
    template <typename X>
    __host__ __device__ constexpr auto& AsType()
    {
        constexpr bool is_known_view =
            is_same<X, d256_t>::value || is_same<X, d128_t>::value ||
            is_same<X, d64_t>::value || is_same<X, d32_t>::value ||
            is_same<X, d16_t>::value || is_same<X, d8_t>::value ||
            is_same<X, d4_t>::value || is_same<X, d2_t>::value ||
            is_same<X, d1_t>::value;
        static_assert(is_known_view, "wrong!");

        if constexpr(is_same<X, d256_t>::value)
        {
            return data_.d256x1_;
        }
        else if constexpr(is_same<X, d128_t>::value)
        {
            return data_.d128x2_;
        }
        else if constexpr(is_same<X, d64_t>::value)
        {
            return data_.d64x4_;
        }
        else if constexpr(is_same<X, d32_t>::value)
        {
            return data_.d32x8_;
        }
        else if constexpr(is_same<X, d16_t>::value)
        {
            return data_.d16x16_;
        }
        else if constexpr(is_same<X, d8_t>::value)
        {
            return data_.d8x32_;
        }
        else if constexpr(is_same<X, d4_t>::value)
        {
            return data_.d4x64_;
        }
        else if constexpr(is_same<X, d2_t>::value)
        {
            return data_.d2x128_;
        }
        else if constexpr(is_same<X, d1_t>::value)
        {
            return data_.d1x256_;
        }
    }
};
// fp32
// fp32
using float2_t = typename vector_type<float, 2>::type;
using float2_t = typename vector_type<float, 2>::type;
using float4_t = typename vector_type<float, 4>::type;
using float4_t = typename vector_type<float, 4>::type;
...
...
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
View file @
90ec6a19
...
@@ -102,30 +102,30 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
...
@@ -102,30 +102,30 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#else
#else
constexpr
index_t
BlockSize
=
6
4
;
constexpr
index_t
BlockSize
=
25
6
;
constexpr
index_t
GemmMPerBlock
=
6
4
;
constexpr
index_t
GemmMPerBlock
=
25
6
;
constexpr
index_t
GemmNPerBlock
=
6
4
;
constexpr
index_t
GemmNPerBlock
=
25
6
;
constexpr
index_t
GemmKPerBlock
=
8
;
constexpr
index_t
GemmKPerBlock
=
8
;
constexpr
index_t
GemmMPerWave
=
64
;
constexpr
index_t
GemmMPerWave
=
64
;
constexpr
index_t
GemmNPerWave
=
64
;
constexpr
index_t
GemmNPerWave
=
64
;
constexpr
index_t
GemmKPerWave
=
1
;
constexpr
index_t
GemmKPerWave
=
4
;
constexpr
index_t
MRepeat
=
1
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
1
;
constexpr
index_t
NRepeat
=
2
;
using
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
=
Sequence
<
4
,
2
>
;
using
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
=
Sequence
<
2
,
4
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
=
Sequence
<
2
,
32
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
=
Sequence
<
4
,
64
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
=
1
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
=
2
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmM
=
1
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmM
=
1
;
using
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
=
Sequence
<
4
,
2
>
;
using
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
=
Sequence
<
2
,
4
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
=
Sequence
<
2
,
32
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
=
Sequence
<
4
,
64
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmN
=
4
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmN
=
4
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
=
1
;
#endif
#endif
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment