Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
f5ae909b
"...git@developer.sourcefind.cn:wangsen/megatron-lm.git" did not exist on "7c19b3a8ff1f6961e6aaec283a9dcf261a51efac"
Commit
f5ae909b
authored
Jan 11, 2022
by
Jing Zhang
Browse files
add bf16 builtins
parent
41fb383f
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
53 deletions
+12
-53
composable_kernel/include/utility/amd_xdlops.hpp
composable_kernel/include/utility/amd_xdlops.hpp
+6
-47
host/driver_offline/src/conv_fwd_driver_offline.cpp
host/driver_offline/src/conv_fwd_driver_offline.cpp
+6
-6
No files found.
composable_kernel/include/utility/amd_xdlops.hpp
View file @
f5ae909b
...
@@ -5,45 +5,6 @@
...
@@ -5,45 +5,6 @@
namespace
ck
{
namespace
ck
{
// A, B, C, cbsz, abid, blgp
// bfp16
extern
"C"
__device__
float16_t
llvm_intrin_amdgcn_mfma_f32_32x32x8bf16_1k
(
ushort4_t
,
ushort4_t
,
float16_t
,
int
,
int
,
int
)
__asm
(
"llvm.amdgcn.mfma.f32.32x32x8bf16.1k"
);
extern
"C"
__device__
float4_t
llvm_intrin_amdgcn_mfma_f32_16x16x16bf16_1k
(
ushort4_t
,
ushort4_t
,
float4_t
,
int
,
int
,
int
)
__asm
(
"llvm.amdgcn.mfma.f32.16x16x16bf16.1k"
);
extern
"C"
__device__
float32_t
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16
(
ushort2_t
,
ushort2_t
,
float32_t
,
int
,
int
,
int
)
__asm
(
"llvm.amdgcn.mfma.f32.32x32x2bf16"
);
extern
"C"
__device__
float16_t
llvm_intrin_amdgcn_mfma_f32_32x32x4bf16
(
ushort2_t
,
ushort2_t
,
float16_t
,
int
,
int
,
int
)
__asm
(
"llvm.amdgcn.mfma.f32.32x32x4bf16"
);
extern
"C"
__device__
float4_t
llvm_intrin_amdgcn_mfma_f32_16x16x8bf16
(
ushort2_t
,
ushort2_t
,
float4_t
,
int
,
int
,
int
)
__asm
(
"llvm.amdgcn.mfma.f32.16x16x8bf16"
);
extern
"C"
__device__
float16_t
llvm_intrin_amdgcn_mfma_f32_16x16x2bf16
(
ushort2_t
,
ushort2_t
,
float16_t
,
int
,
int
,
int
)
__asm
(
"llvm.amdgcn.mfma.f32.16x16x2bf16"
);
extern
"C"
__device__
float4_t
llvm_intrin_amdgcn_mfma_f32_4x4x2bf16
(
ushort2_t
,
ushort2_t
,
float4_t
,
int
,
int
,
int
)
__asm
(
"llvm.amdgcn.mfma.f32.4x4x2bf16"
);
// int8
extern
"C"
__device__
int32x32_t
llvm_intrin_amdgcn_mfma_i32_32x32x4i8
(
int
,
int
,
int32x32_t
,
int
,
int
,
int
)
__asm
(
"llvm.amdgcn.mfma.i32.32x32x4i8"
);
extern
"C"
__device__
int32x16_t
llvm_intrin_amdgcn_mfma_i32_16x16x4i8
(
int
,
int
,
int32x16_t
,
int
,
int
,
int
)
__asm
(
"llvm.amdgcn.mfma.i32.16x16x4i8"
);
extern
"C"
__device__
int32x4_t
llvm_intrin_amdgcn_mfma_i32_4x4x4i8
(
int
,
int
,
int32x4_t
,
int
,
int
,
int
)
__asm
(
"llvm.amdgcn.mfma.i32.4x4x4i8"
);
extern
"C"
__device__
int32x16_t
llvm_intrin_amdgcn_mfma_i32_32x32x8i8
(
int
,
int
,
int32x16_t
,
int
,
int
,
int
)
__asm
(
"llvm.amdgcn.mfma.i32.32x32x8i8"
);
extern
"C"
__device__
int32x4_t
llvm_intrin_amdgcn_mfma_i32_16x16x16i8
(
int
,
int
,
int32x4_t
,
int
,
int
,
int
)
__asm
(
"llvm.amdgcn.mfma.i32.16x16x16i8"
);
// fp32
// fp32
template
<
index_t
MPerWave
,
index_t
NPerWave
>
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_f32_32x32x1f32
;
struct
intrin_mfma_f32_32x32x1f32
;
...
@@ -248,8 +209,7 @@ struct intrin_mfma_f32_32x32x8bf16_1k<32, 32>
...
@@ -248,8 +209,7 @@ struct intrin_mfma_f32_32x32x8bf16_1k<32, 32>
template
<
class
FloatC
>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
ushort4_t
&
reg_a
,
const
ushort4_t
&
reg_b
,
FloatC
&
reg_c
)
__device__
static
void
Run
(
const
ushort4_t
&
reg_a
,
const
ushort4_t
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c
.
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
reg_c
.
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_mfma_f32_32x32x8bf16_1k
(
llvm_intrin_amdgcn_mfma_f32_32x32x8bf16_1k
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
}
}
};
};
...
@@ -263,8 +223,7 @@ struct intrin_mfma_f32_16x16x16bf16_1k<16, 16>
...
@@ -263,8 +223,7 @@ struct intrin_mfma_f32_16x16x16bf16_1k<16, 16>
template
<
class
FloatC
>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
ushort4_t
&
reg_a
,
const
ushort4_t
&
reg_b
,
FloatC
&
reg_c
)
__device__
static
void
Run
(
const
ushort4_t
&
reg_a
,
const
ushort4_t
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_mfma_f32_16x16x16bf16_1k
(
llvm_intrin_amdgcn_mfma_f32_16x16x16bf16_1k
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
}
}
};
};
...
@@ -278,7 +237,7 @@ struct intrin_mfma_f32_32x32x4bf16<32, 32>
...
@@ -278,7 +237,7 @@ struct intrin_mfma_f32_32x32x4bf16<32, 32>
template
<
class
FloatC
>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
ushort2_t
&
reg_a
,
const
ushort2_t
&
reg_b
,
FloatC
&
reg_c
)
__device__
static
void
Run
(
const
ushort2_t
&
reg_a
,
const
ushort2_t
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c
.
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
llvm_intr
in_amdgcn_mfma_f32_32x32x4bf16
(
reg_c
.
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
__built
in_amdgcn_mfma_f32_32x32x4bf16
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
}
}
};
};
...
@@ -292,7 +251,7 @@ struct intrin_mfma_f32_16x16x8bf16<16, 16>
...
@@ -292,7 +251,7 @@ struct intrin_mfma_f32_16x16x8bf16<16, 16>
template
<
class
FloatC
>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
ushort2_t
&
reg_a
,
const
ushort2_t
&
reg_b
,
FloatC
&
reg_c
)
__device__
static
void
Run
(
const
ushort2_t
&
reg_a
,
const
ushort2_t
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
llvm_intr
in_amdgcn_mfma_f32_
16x16x8
bf16
(
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
__built
in_amdgcn_mfma_f32_
16x16x8
bf16
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
}
}
};
};
...
...
host/driver_offline/src/conv_fwd_driver_offline.cpp
View file @
f5ae909b
...
@@ -12,9 +12,9 @@
...
@@ -12,9 +12,9 @@
#include "host_tensor_generator.hpp"
#include "host_tensor_generator.hpp"
#include "conv_common.hpp"
#include "conv_common.hpp"
#include "device_tensor.hpp"
#include "device_tensor.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp"
//
#include "device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp"
//
#include "device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp"
#include "device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp"
//
#include "device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
...
@@ -250,15 +250,15 @@ int main(int argc, char* argv[])
...
@@ -250,15 +250,15 @@ int main(int argc, char* argv[])
constexpr
auto
Wo
=
(
Wi
+
in_left_pad_w
+
in_right_pad_w
-
XEff
)
/
conv_stride_w
+
I1
;
constexpr
auto
Wo
=
(
Wi
+
in_left_pad_w
+
in_right_pad_w
-
XEff
)
/
conv_stride_w
+
I1
;
#endif
#endif
#if
1
#if
0
using in_data_t = float;
using in_data_t = float;
using acc_data_t = float;
using acc_data_t = float;
using out_data_t = float;
using out_data_t = float;
#elif
1
#elif
0
using
in_data_t
=
half_t
;
using
in_data_t
=
half_t
;
using
acc_data_t
=
float
;
using
acc_data_t
=
float
;
using
out_data_t
=
half_t
;
using
out_data_t
=
half_t
;
#elif
0
#elif
1
using
in_data_t
=
ushort
;
using
in_data_t
=
ushort
;
using
acc_data_t
=
float
;
using
acc_data_t
=
float
;
using
out_data_t
=
ushort
;
using
out_data_t
=
ushort
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment