Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b1f7f365
Commit
b1f7f365
authored
Mar 25, 2021
by
Jing Zhang
Browse files
add int8x8_t
parent
8753e615
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
125 additions
and
20 deletions
+125
-20
composable_kernel/include/utility/amd_buffer_addressing_v2.hpp
...sable_kernel/include/utility/amd_buffer_addressing_v2.hpp
+10
-1
composable_kernel/include/utility/amd_inline_asm.hpp
composable_kernel/include/utility/amd_inline_asm.hpp
+31
-0
composable_kernel/include/utility/float_type.amd.hpp.in
composable_kernel/include/utility/float_type.amd.hpp.in
+82
-17
driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
...convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
+1
-1
driver/src/conv_driver.cpp
driver/src/conv_driver.cpp
+1
-1
No files found.
composable_kernel/include/utility/amd_buffer_addressing_v2.hpp
View file @
b1f7f365
...
@@ -142,7 +142,8 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
...
@@ -142,7 +142,8 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
index_t
src_wave_addr_offset
)
index_t
src_wave_addr_offset
)
{
{
static_assert
((
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
static_assert
((
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
)),
(
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
int32x2_t
>::
value
&&
(
N
==
1
)),
"wrong! not implemented"
);
"wrong! not implemented"
);
if
constexpr
(
is_same
<
T
,
float
>::
value
)
if
constexpr
(
is_same
<
T
,
float
>::
value
)
...
@@ -205,6 +206,14 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
...
@@ -205,6 +206,14 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
return
tmp
.
Vector
();
return
tmp
.
Vector
();
}
}
}
}
else
if
constexpr
(
is_same
<
T
,
int32x2_t
>::
value
)
{
if
constexpr
(
N
==
1
)
{
return
__llvm_amdgcn_raw_buffer_load_i32x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
}
}
}
}
template
<
typename
T
,
index_t
N
>
template
<
typename
T
,
index_t
N
>
...
...
composable_kernel/include/utility/amd_inline_asm.hpp
View file @
b1f7f365
...
@@ -215,5 +215,36 @@ __device__ void amd_assembly_outer_product_1x4(int8x4_t a,
...
@@ -215,5 +215,36 @@ __device__ void amd_assembly_outer_product_1x4(int8x4_t a,
#endif
#endif
}
}
__device__
void
amd_assembly_outer_product_1x4
(
int8x8_t
a
,
int8x8_t
b0
,
int8x8_t
b1
,
int8x8_t
b2
,
int8x8_t
b3
,
int32_t
&
c0
,
int32_t
&
c1
,
int32_t
&
c2
,
int32_t
&
c3
)
{
amd_assembly_outer_product_1x4
(
a
.
Vectors
(
Number
<
4
>
{})[
Number
<
0
>
{}],
b0
.
Vectors
(
Number
<
4
>
{})[
Number
<
0
>
{}],
b1
.
Vectors
(
Number
<
4
>
{})[
Number
<
0
>
{}],
b2
.
Vectors
(
Number
<
4
>
{})[
Number
<
0
>
{}],
b3
.
Vectors
(
Number
<
4
>
{})[
Number
<
0
>
{}],
c0
,
c1
,
c2
,
c3
);
amd_assembly_outer_product_1x4
(
a
.
Vectors
(
Number
<
4
>
{})[
Number
<
1
>
{}],
b0
.
Vectors
(
Number
<
4
>
{})[
Number
<
1
>
{}],
b1
.
Vectors
(
Number
<
4
>
{})[
Number
<
1
>
{}],
b2
.
Vectors
(
Number
<
4
>
{})[
Number
<
1
>
{}],
b3
.
Vectors
(
Number
<
4
>
{})[
Number
<
1
>
{}],
c0
,
c1
,
c2
,
c3
);
}
}
// namespace ck
}
// namespace ck
#endif
#endif
composable_kernel/include/utility/float_type.amd.hpp.in
View file @
b1f7f365
...
@@ -168,6 +168,27 @@ struct vector_type<T, 8>
...
@@ -168,6 +168,27 @@ struct vector_type<T, 8>
__host__ __device__ constexpr auto& Vectors(Number<8>) { return data_.d8x1_; }
__host__ __device__ constexpr auto& Vectors(Number<8>) { return data_.d8x1_; }
};
};
// fp32
using float2_t = typename vector_type<float, 2>::type;
using float4_t = typename vector_type<float, 4>::type;
using float8_t = typename vector_type<float, 8>::type;
// fp16
using half_t = _Float16;
using half2_t = typename vector_type<half_t, 2>::type;
using half4_t = typename vector_type<half_t, 4>::type;
using half8_t = typename vector_type<half_t, 8>::type;
// bfp16
using ushort2_t = typename vector_type<ushort, 2>::type;
using ushort4_t = typename vector_type<ushort, 4>::type;
using ushort8_t = typename vector_type<ushort, 8>::type;
// i32
using int32x2_t = typename vector_type<int32_t, 2>::type;
using int32x4_t = typename vector_type<int32_t, 4>::type;
using int32x8_t = typename vector_type<int32_t, 8>::type;
template <>
template <>
struct vector_type<int8_t, 2>
struct vector_type<int8_t, 2>
{
{
...
@@ -250,31 +271,61 @@ struct vector_type<int8_t, 4>
...
@@ -250,31 +271,61 @@ struct vector_type<int8_t, 4>
__host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x1_; }
__host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x1_; }
};
};
// fp32
template <>
using float2_t = typename vector_type<float, 2>::type;
struct vector_type<int8_t, 8>
using float4_t = typename vector_type<float, 4>::type;
{
using float8_t = typename vector_type<float, 8>::type;
using d1_t = int8_t;
typedef int16_t d2_t;
typedef int32_t d4_t;
typedef int32x2_t d8_t;
// fp16
using type = d8_t;
using half_t = _Float16;
using half2_t = typename vector_type<half_t, 2>::type;
using half4_t = typename vector_type<half_t, 4>::type;
using half8_t = typename vector_type<half_t, 8>::type;
// bfp16
union
using ushort2_t = typename vector_type<ushort, 2>::type;
{
using ushort4_t = typename vector_type<ushort, 4>::type;
d8_t d8_;
using ushort8_t = typename vector_type<ushort, 8>::type;
StaticallyIndexedArray<d1_t, 8> d1x8_;
StaticallyIndexedArray<d2_t, 4> d2x4_;
StaticallyIndexedArray<d4_t, 2> d4x2_;
StaticallyIndexedArray<d8_t, 1> d8x1_;
} data_;
// i32
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
using int32x2_t = typename vector_type<int32_t, 2>::type;
using int32x4_t = typename vector_type<int32_t, 4>::type;
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
using int32x8_t = typename vector_type<int32_t, 8>::type;
__host__ __device__ static constexpr index_t Size() { return 8; }
__host__ __device__ constexpr const auto& Vector() const { return data_.d8_; }
__host__ __device__ constexpr auto& Vector() { return data_.d8_; }
__host__ __device__ constexpr const auto& Scalars() const { return data_.d1x8_; }
__host__ __device__ constexpr auto& Scalars() { return data_.d1x8_; }
__host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x8_; }
__host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x4_; }
__host__ __device__ constexpr const auto& Vectors(Number<4>) const { return data_.d4x2_; }
__host__ __device__ constexpr const auto& Vectors(Number<8>) const { return data_.d8x1_; }
__host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x8_; }
__host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x4_; }
__host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x2_; }
__host__ __device__ constexpr auto& Vectors(Number<8>) { return data_.d8x1_; }
};
// i8
// i8
// hack for int8x4_t, because compiler does not have native support for int8x4_t
// hack for int8x4_t, because compiler does not have native support for int8x4_t
// int8x4_t is defined as int32_t
// int8x4_t is defined as int32_t
using int8x4_t = typename vector_type<int8_t, 4>::type;
using int8x4_t = typename vector_type<int8_t, 4>::type;
using int8x8_t = vector_type<int8_t, 8>;
// data type conversion
// data type conversion
template <typename T>
template <typename T>
...
@@ -339,6 +390,20 @@ struct inner_product_with_conversion
...
@@ -339,6 +390,20 @@ struct inner_product_with_conversion
return acc;
return acc;
}
}
__device__ T operator()(int8x8_t a, int8x8_t b) const
{
const vector_type<int8_t, 8> a_vector{a};
const vector_type<int8_t, 8> b_vector{b};
T acc = 0;
static_for<0, 8, 1>{}([&](auto i) {
acc += convert(a_vector.Scalars()[i]) * convert(b_vector.Scalars()[i]);
});
return acc;
}
};
};
} // namespace ck
} // namespace ck
...
...
driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
View file @
b1f7f365
...
@@ -113,7 +113,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
...
@@ -113,7 +113,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
constexpr
index_t
KPerBlock
=
16
;
constexpr
index_t
KPerBlock
=
16
;
constexpr
index_t
HoPerBlock
=
8
;
constexpr
index_t
HoPerBlock
=
8
;
constexpr
index_t
WoPerBlock
=
32
;
constexpr
index_t
WoPerBlock
=
32
;
constexpr
index_t
EPerBlock
=
4
;
constexpr
index_t
EPerBlock
=
2
;
constexpr
index_t
KPerThread
=
16
;
constexpr
index_t
KPerThread
=
16
;
constexpr
index_t
HoPerThread
=
2
;
constexpr
index_t
HoPerThread
=
2
;
...
...
driver/src/conv_driver.cpp
View file @
b1f7f365
...
@@ -642,7 +642,7 @@ int main(int argc, char* argv[])
...
@@ -642,7 +642,7 @@ int main(int argc, char* argv[])
using
out_data_t
=
int8_t
;
using
out_data_t
=
int8_t
;
#elif 1
#elif 1
using
in_data_t
=
int8_t
;
using
in_data_t
=
int8_t
;
constexpr
index_t
in_vector_size
=
4
;
constexpr
index_t
in_vector_size
=
8
;
using
acc_data_t
=
int32_t
;
using
acc_data_t
=
int32_t
;
using
out_data_t
=
int8_t
;
using
out_data_t
=
int8_t
;
#endif
#endif
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment