gaoqiong / composable_kernel

Commit 0ff1d1f8
Authored Aug 31, 2023 by Bartlomiej Wroblewski
Parent: 7b7dd69d

Review: Remove hardcoded datatypes

Showing 2 changed files with 87 additions and 39 deletions:
  include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp  (+39, -15)
  include/ck/utility/amd_gemm_dpp.hpp                (+48, -24)
include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp

@@ -11,9 +11,9 @@ namespace ck {
 enum struct DppInstr
 {
-    dpp8_16x16x2 = 0,
-    dpp8_8x32x2,
-    dpp8_32x8x2
+    dpp8_f16_16x16x2 = 0,
+    dpp8_f16_8x32x2,
+    dpp8_f16_32x8x2
 };
 
 /**
@@ -42,7 +42,7 @@ template <DppInstr instr>
 struct dpp_type;
 
 template <>
-struct dpp_type<DppInstr::dpp8_32x8x2>
+struct dpp_type<DppInstr::dpp8_f16_32x8x2>
 {
     static constexpr index_t wave_size = 32;
     static constexpr index_t lanegroup_size = 8;
@@ -54,17 +54,25 @@ struct dpp_type<DppInstr::dpp8_32x8x2>
     static constexpr index_t n_per_thread = 1;
     static constexpr index_t k_per_dpp = 2;
     static constexpr bool share_a = true;
+    using base_type = half_t;
 
     template <index_t MPerDpp, index_t NPerDpp, class FloatA, class FloatB, class FloatC>
     __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
-        dpp8::RunGemm<m_per_lanegroup, n_per_lanegroup, k_per_dpp, FloatA, FloatB, FloatC, share_a>(
-            a, b, reg_c);
+        dpp8::DppInstrRunner<m_per_thread,
+                             n_per_thread,
+                             k_per_dpp,
+                             base_type,
+                             FloatA,
+                             FloatB,
+                             FloatC,
+                             share_a>{}
+            .Run(a, b, reg_c);
     }
 };
 
 template <>
-struct dpp_type<DppInstr::dpp8_8x32x2>
+struct dpp_type<DppInstr::dpp8_f16_8x32x2>
 {
     static constexpr index_t wave_size = 32;
     static constexpr index_t lanegroup_size = 8;
@@ -76,17 +84,25 @@ struct dpp_type<DppInstr::dpp8_8x32x2>
     static constexpr index_t n_per_thread = 1;
     static constexpr index_t k_per_dpp = 2;
     static constexpr bool share_a = true;
+    using base_type = half_t;
 
     template <index_t MPerDpp, index_t NPerDpp, class FloatA, class FloatB, class FloatC>
     __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
-        dpp8::RunGemm<m_per_lanegroup, n_per_lanegroup, k_per_dpp, FloatA, FloatB, FloatC, share_a>(
-            a, b, reg_c);
+        dpp8::DppInstrRunner<m_per_thread,
+                             n_per_thread,
+                             k_per_dpp,
+                             base_type,
+                             FloatA,
+                             FloatB,
+                             FloatC,
+                             share_a>{}
+            .Run(a, b, reg_c);
     }
 };
 
 template <>
-struct dpp_type<DppInstr::dpp8_16x16x2>
+struct dpp_type<DppInstr::dpp8_f16_16x16x2>
 {
     static constexpr index_t wave_size = 32;
     static constexpr index_t lanegroup_size = 8;
@@ -98,12 +114,20 @@ struct dpp_type<DppInstr::dpp8_16x16x2>
     static constexpr index_t n_per_thread = 1;
     static constexpr index_t k_per_dpp = 2;
     static constexpr bool share_a = true;
+    using base_type = half_t;
 
     template <index_t MPerDpp, index_t NPerDpp, class FloatA, class FloatB, class FloatC>
     __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
-        dpp8::RunGemm<m_per_lanegroup, n_per_lanegroup, k_per_dpp, FloatA, FloatB, FloatC, share_a>(
-            a, b, reg_c);
+        dpp8::DppInstrRunner<m_per_thread,
+                             n_per_thread,
+                             k_per_dpp,
+                             base_type,
+                             FloatA,
+                             FloatB,
+                             FloatC,
+                             share_a>{}
+            .Run(a, b, reg_c);
     }
 };
@@ -116,19 +140,19 @@ struct DppSelector
     template <>
     static constexpr auto GetDpp<half_t, 8, 32>()
     {
-        return DppInstr::dpp8_8x32x2;
+        return DppInstr::dpp8_f16_8x32x2;
     }
 
     template <>
     static constexpr auto GetDpp<half_t, 16, 16>()
     {
-        return DppInstr::dpp8_16x16x2;
+        return DppInstr::dpp8_f16_16x16x2;
     }
 
     template <>
     static constexpr auto GetDpp<half_t, 32, 8>()
     {
-        return DppInstr::dpp8_32x8x2;
+        return DppInstr::dpp8_f16_32x8x2;
    }
 
     static constexpr auto selected_dpp = dpp_type<GetDpp<base_type, MPerDpp, NPerDpp>()>{};
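In this first file, the substance of the change is that every `dpp_type` specialization now exposes a `base_type` and hands it, together with its tile constants, to a `dpp8::DppInstrRunner` object instead of calling the old `dpp8::RunGemm` free function; the actual element types are then resolved inside the runner. Below is a minimal, self-contained host-side sketch of that forwarding pattern. It is an illustration only: `Runner`, `datatypes`, the stand-in `half_t`, and all numeric values are assumptions made for the sketch, not CK's real definitions.

#include <cstdio>

// Illustrative stand-ins for CK identifiers; assumptions, not the library's API.
using index_t = int;
struct half_t { float v; }; // toy stand-in for a 16-bit float type

// Trait in the spirit of the new dpp_datatypes<>: maps a base input type to
// the element types and the K-depth handled per instruction.
template <class BaseInputType>
struct datatypes;

template <>
struct datatypes<half_t>
{
    using a_dtype = half_t;
    using b_dtype = half_t;
    using c_dtype = float;
    static constexpr index_t k_per_instr = 2;
};

// Runner in the spirit of DppInstrRunner: reads its types from the trait
// instead of hardcoding them.
template <index_t MPerThread, index_t NPerThread, index_t KPerThread,
          class BaseInputType, bool ShareA>
struct Runner
{
    using conf = datatypes<BaseInputType>;
    // The accumulator type also comes from the trait rather than being hardcoded.
    using CDataType = typename conf::c_dtype;

    void Run() const
    {
        // One instruction covers k_per_instr K-elements, so this many chunks run per thread.
        std::printf("k chunks per thread: %d\n", KPerThread / conf::k_per_instr);
    }
};

// Per-instruction descriptor in the spirit of dpp_type<...>: it names a
// base_type and forwards to the runner, mirroring run() in the diff.
struct dpp_f16_32x8x2
{
    static constexpr index_t m_per_thread = 4; // illustrative values
    static constexpr index_t n_per_thread = 1;
    static constexpr index_t k_per_dpp    = 2;
    static constexpr bool    share_a      = true;
    using base_type = half_t;

    void run() const
    {
        Runner<m_per_thread, n_per_thread, k_per_dpp, base_type, share_a>{}.Run();
    }
};

int main()
{
    dpp_f16_32x8x2{}.run(); // prints "k chunks per thread: 1"
    return 0;
}

The apparent payoff of this structure is that supporting a new element type only requires a new trait specialization; the per-instruction descriptors and the runner body stay unchanged.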
include/ck/utility/amd_gemm_dpp.hpp

@@ -11,33 +11,57 @@ namespace ck {
 namespace dpp8 {
 
-template <index_t MPerLanegroup,
-          index_t NPerLanegroup,
-          index_t KPerLanegroup,
-          class FloatA,
-          class FloatB,
-          class FloatVecC,
-          bool ShareA>
-__device__ void RunGemm(const FloatA& a, const FloatB& b, FloatVecC& c_vec)
-{
-    constexpr index_t c_dim = ShareA ? MPerLanegroup : NPerLanegroup;
-
-    const vector_type<half_t, KPerLanegroup> a_vector{a};
-    const vector_type<half_t, KPerLanegroup> b_vector{b};
-
-    static_for<0, c_dim, 1>{}([&](auto c_idx) {
-        float c = c_vec.template AsType<float>()(c_idx);
-        // Next `c_idx` implies that we need to pull data from the next lane.
-        constexpr index_t source_lane = c_idx;
-        static_for<0, KPerLanegroup / 2, 1>{}([&](auto k_chunk) {
-            const auto a_half2 = a_vector.template AsType<half2_t>()[k_chunk];
-            const auto b_half2 = b_vector.template AsType<half2_t>()[k_chunk];
-            ck::dpp8::inner_product_dpp<half2_t, half2_t, float, source_lane, ShareA>(
-                a_half2, b_half2, c);
-        });
-        c_vec.template AsType<float>()(c_idx) = c;
-    });
-}
+template <class ABDataType>
+struct dpp_datatypes;
+
+template <>
+struct dpp_datatypes<half_t>
+{
+    // Dot product of `half2_t` and `half2_t` to get `float`. Reducing 2 elements from K in a
+    // single instruction.
+    using a_dtype = half_t;
+    using b_dtype = half_t;
+    using c_dtype = float;
+    static constexpr index_t k_per_instr = 2;
+};
+
+template <index_t MPerThread,
+          index_t NPerThread,
+          index_t KPerThread,
+          class BaseInputType,
+          class AVecDataType,
+          class BVecDataType,
+          class CVecDataType,
+          bool ShareA>
+struct DppInstrRunner
+{
+    static constexpr auto datatypes_conf = dpp_datatypes<BaseInputType>{};
+    using ADataType = typename decltype(datatypes_conf)::a_dtype;
+    using BDataType = typename decltype(datatypes_conf)::b_dtype;
+    using CDataType = typename decltype(datatypes_conf)::c_dtype;
+
+    __device__ void Run(const AVecDataType& a_vec, const BVecDataType& b_vec, CVecDataType& c_vec)
+    {
+        constexpr index_t num_c_elems_per_thread = ShareA ? MPerThread : NPerThread;
+
+        const vector_type<ADataType, KPerThread> a_vector{a_vec};
+        const vector_type<BDataType, KPerThread> b_vector{b_vec};
+
+        static_for<0, num_c_elems_per_thread, 1>{}([&](auto c_idx) {
+            float c = c_vec.template AsType<CDataType>()(c_idx);
+            // Next `c_idx` implies that we need to pull data from the next lane.
+            constexpr index_t source_lane = c_idx;
+            static_for<0, KPerThread / datatypes_conf.k_per_instr, 1>{}([&](auto k_chunk) {
+                const auto a_k_vec = a_vector.template AsType<AVecDataType>()[k_chunk];
+                const auto b_k_vec = b_vector.template AsType<BVecDataType>()[k_chunk];
+                ck::dpp8::inner_product_dpp<AVecDataType, BVecDataType, CDataType, source_lane, ShareA>(
+                    a_k_vec, b_k_vec, c);
+            });
+            c_vec.template AsType<CDataType>()(c_idx) = c;
+        });
+    }
+};
 
 } // namespace dpp8
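In this second file, the hardcoded `half_t`/`half2_t`/`float` choices move into the `dpp_datatypes` trait, which `DppInstrRunner` queries for its A/B/C element types and for `k_per_instr`, the K-depth reduced by one DPP instruction. Extending the scheme to another base type would, in principle, come down to adding one more specialization. The snippet below is a purely hypothetical sketch of what that could look like; the `int8_t` mapping and its `k_per_instr` value are assumptions made for illustration, not part of this commit and not necessarily supported by the underlying instructions.

#include <cstdint>

// Stand-ins so this hypothetical sketch compiles on its own; in CK these come
// from the library headers.
using index_t = int;

template <class BaseInputType>
struct dpp_datatypes; // primary template, as introduced in the diff above

// Hypothetical extra specialization (NOT part of commit 0ff1d1f8): an int8
// dot product reducing 4 K-elements into an int32 accumulator per instruction.
template <>
struct dpp_datatypes<std::int8_t>
{
    using a_dtype = std::int8_t;
    using b_dtype = std::int8_t;
    using c_dtype = std::int32_t;
    static constexpr index_t k_per_instr = 4; // assumed value
};

// A runner in the style of DppInstrRunner would then pick these up automatically,
// e.g. running KPerThread / dpp_datatypes<std::int8_t>::k_per_instr chunks per thread.
static_assert(dpp_datatypes<std::int8_t>::k_per_instr == 4, "trait is usable at compile time");

int main() { return 0; }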