gaoqiong / composable_kernel

Commit fa6d8037, authored Jul 29, 2019 by Tejash Shah
Parent: 2185affb

Changed scalar type to vector type for threadwise gemm for fp16 and bfloat16 data types

Showing 4 changed files with 105 additions and 65 deletions:
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_buffer.hpp (+49, -7)
composable_kernel/include/tensor_operation/threadwise_gemm.hpp (+29, -54)
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp (+0, -1)
driver/src/driver.cpp (+27, -3)
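The commit message above describes replacing per-scalar loads in the threadwise GEMM with vector loads for fp16 (4 lanes) and bfloat16 (2 lanes), converting each lane before accumulating in float. As a rough standalone illustration of that idea only (hypothetical names, not the repository's vector_type/CVT_FLOAT2ACCUM API; assumes a compiler with _Float16 support, e.g. recent Clang/GCC):

// Minimal sketch: view an fp16 buffer as 4-wide vectors and accumulate in float.
// 'half4' is a hypothetical stand-in for the kernel's vector_type<half, 4>::MemoryType.
#include <cstdio>

struct half4
{
    _Float16 x[4]; // 4 fp16 lanes per access
};

// Dot product of two buffers holding n half4 vectors, accumulated in float,
// mirroring the convert-then-accumulate loop added to threadwise_gemm.
float dot_fp16x4(const _Float16* p_a, const _Float16* p_b, int n)
{
    const half4* a = reinterpret_cast<const half4*>(p_a); // vectorized view
    const half4* b = reinterpret_cast<const half4*>(p_b);

    float acc = 0.0f;
    for(int i = 0; i < n; ++i)
    {
        for(int v = 0; v < 4; ++v)
        {
            // convert each fp16 lane to float before the multiply-accumulate
            acc += static_cast<float>(a[i].x[v]) * static_cast<float>(b[i].x[v]);
        }
    }
    return acc;
}

int main()
{
    _Float16 a[8], b[8];
    for(int i = 0; i < 8; ++i)
    {
        a[i] = static_cast<_Float16>(i);
        b[i] = static_cast<_Float16>(1);
    }
    std::printf("%f\n", static_cast<double>(dot_fp16x4(a, b, 2))); // expect 28.000000
    return 0;
}

A bfloat16 path would follow the same shape with 2 lanes per access instead of 4.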
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_buffer.hpp

@@ -308,9 +308,23 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b
     blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
                                                 p_wei_register_clipboard);

-    // LDS double buffer: GEMM on current data
-    blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
+    // If src/dst matrix datatype is bfloat16/float16 (vector size 2/4 respectively)
+    static_if<std::is_same<Float, half>::value>{}([&](auto) {
+        using vector_t = typename vector_type<half, 4>::MemoryType;
+        vector_t* vec_wei_block_now = reinterpret_cast<vector_t*>(p_wei_block_now);
+        vector_t* vec_in_block_now  = reinterpret_cast<vector_t*>(p_in_block_now);
+
+        // LDS double buffer: GEMM on current data
+        blockwise_gemm.Run(vec_wei_block_now, vec_in_block_now, p_out_thread);
+    }).Else([&](auto) {
+        using vector_t = typename vector_type<ushort, 2>::MemoryType;
+        vector_t* vec_wei_block_now = reinterpret_cast<vector_t*>(p_wei_block_now);
+        vector_t* vec_in_block_now  = reinterpret_cast<vector_t*>(p_in_block_now);
+
+        // LDS double buffer: GEMM on current data
+        blockwise_gemm.Run(vec_wei_block_now, vec_in_block_now, p_out_thread);
+    });

     // LDS double buffer: store next data to LDS
     blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
                                                 p_in_block_next);

@@ -336,7 +350,22 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b
                                                 p_wei_register_clipboard);

     // LDS double buffer: GEMM on current data
-    blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
+    // If src/dst matrix datatype is bfloat16/float16 (vector size 2/4 respectively)
+    static_if<std::is_same<Float, half>::value>{}([&](auto) {
+        using vector_t = typename vector_type<half, 4>::MemoryType;
+        vector_t* vec_wei_block_now = reinterpret_cast<vector_t*>(p_wei_block_double);
+        vector_t* vec_in_block_now  = reinterpret_cast<vector_t*>(p_in_block_double);
+
+        // LDS double buffer: GEMM on current data
+        blockwise_gemm.Run(vec_wei_block_now, vec_in_block_now, p_out_thread);
+    }).Else([&](auto) {
+        using vector_t = typename vector_type<ushort, 2>::MemoryType;
+        vector_t* vec_wei_block_now = reinterpret_cast<vector_t*>(p_wei_block_double);
+        vector_t* vec_in_block_now  = reinterpret_cast<vector_t*>(p_in_block_double);
+
+        // LDS double buffer: GEMM on current data
+        blockwise_gemm.Run(vec_wei_block_now, vec_in_block_now, p_out_thread);
+    });

     // LDS double buffer: store next data to LDS
     blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,

@@ -348,9 +377,22 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b
     __syncthreads();

     // LDS double buffer: GEMM on current data
-    blockwise_gemm.Run(p_wei_block_double + wei_block_space,
-                       p_in_block_double + in_block_space,
-                       p_out_thread);
+    // If src/dst matrix datatype is bfloat16/float16 (vector size 2/4 respectively)
+    static_if<std::is_same<Float, half>::value>{}([&](auto) {
+        using vector_t = typename vector_type<half, 4>::MemoryType;
+        vector_t* vec_wei_block_now = reinterpret_cast<vector_t*>(p_wei_block_double + wei_block_space);
+        vector_t* vec_in_block_now  = reinterpret_cast<vector_t*>(p_in_block_double + in_block_space);
+
+        // LDS double buffer: GEMM on current data
+        blockwise_gemm.Run(vec_wei_block_now, vec_in_block_now, p_out_thread);
+    }).Else([&](auto) {
+        using vector_t = typename vector_type<ushort, 2>::MemoryType;
+        vector_t* vec_wei_block_now = reinterpret_cast<vector_t*>(p_wei_block_double + wei_block_space);
+        vector_t* vec_in_block_now  = reinterpret_cast<vector_t*>(p_in_block_double + in_block_space);
+
+        // LDS double buffer: GEMM on current data
+        blockwise_gemm.Run(vec_wei_block_now, vec_in_block_now, p_out_thread);
+    });
 }

 // copy output: register to global memory
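The static_if/Else blocks added above pick the vector element type at compile time from the kernel's Float parameter and re-cast the LDS pointers before handing them to blockwise_gemm.Run. A minimal standalone sketch of that dispatch, using C++17 if constexpr in place of the repository's static_if and hypothetical stand-in types (not the kernel's API):

#include <cstdio>
#include <cstdint>
#include <type_traits>

struct half_t { std::uint16_t bits; };   // fp16 storage stand-in
struct bf16_t { std::uint16_t bits; };   // bfloat16 storage stand-in
struct half4  { half_t d[4]; };          // stand-in for vector_type<half, 4>::MemoryType
struct bf16x2 { bf16_t d[2]; };          // stand-in for vector_type<ushort, 2>::MemoryType

// Re-cast a scalar block pointer to the matching vector type and return the
// number of lanes moved per access; the casted pointer is what a GEMM would take.
template <typename Float>
int lanes_per_access(Float* p_block)
{
    if constexpr (std::is_same<Float, half_t>::value)
    {
        using vector_t = half4;
        vector_t* vec_block = reinterpret_cast<vector_t*>(p_block);
        (void)vec_block; // would be passed to the blockwise GEMM here
        return 4;
    }
    else
    {
        using vector_t = bf16x2;
        vector_t* vec_block = reinterpret_cast<vector_t*>(p_block);
        (void)vec_block;
        return 2;
    }
}

int main()
{
    half_t fp16_block[8] = {};
    bf16_t bf16_block[8] = {};
    std::printf("fp16: %d lanes per access, bf16: %d lanes per access\n",
                lanes_per_access(fp16_block), lanes_per_access(bf16_block));
    return 0;
}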
composable_kernel/include/tensor_operation/threadwise_gemm.hpp

@@ -57,37 +57,18 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
             }
         }
     }).Else([&](auto) {
-        static_if<std::is_same<Float, half>::value>{}([&](auto) {
-            using vector_t = typename vector_type<Float, 4>::MemoryType;
-
-            for(index_t i = 0; i < NRow; ++i)
-            {
-                for(index_t j = 0; j < NCol; ++j)
-                {
-                    const index_t src_index = src_mtx.GetOffsetFromMultiIndex(i, j);
-                    const index_t dst_index = dst_mtx.GetOffsetFromMultiIndex(i, j);
-
-                    *reinterpret_cast<vector_t*>(&p_dst[dst_index * 4]) =
-                        *reinterpret_cast<const vector_t*>(&p_src[src_index * 4]);
-                }
-            }
-        }).Else([&](auto) {
-            using vector_t = typename vector_type<Float, 2>::MemoryType;
-
-            for(index_t i = 0; i < NRow; ++i)
-            {
-                for(index_t j = 0; j < NCol; ++j)
-                {
-                    const index_t src_index = src_mtx.GetOffsetFromMultiIndex(i, j);
-                    const index_t dst_index = dst_mtx.GetOffsetFromMultiIndex(i, j);
-
-                    *reinterpret_cast<vector_t*>(&p_dst[dst_index * 2]) =
-                        *reinterpret_cast<const vector_t*>(&p_src[src_index * 2]);
-                }
-            }
-        });
+        // For half/bfloat16, Float type is half4/bfloat2 respectively.
+        // If src/dst matrix datatype is bfloat16/float16 (vector size 2/4 respectively)
+        for(index_t i = 0; i < NRow; ++i)
+        {
+            for(index_t j = 0; j < NCol; ++j)
+            {
+                const index_t src_index = src_mtx.GetOffsetFromMultiIndex(i, j);
+                const index_t dst_index = dst_mtx.GetOffsetFromMultiIndex(i, j);
+
+                *reinterpret_cast<Float*>(&p_dst[dst_index]) =
+                    *reinterpret_cast<const Float*>(&p_src[src_index]);
+            }
+        }
     });
 }

@@ -129,32 +110,26 @@ __device__ void threadwise_gemm(MatrixA,
                 const index_t bindex = b_mtx.GetOffsetFromMultiIndex(k, j);
                 const index_t cindex = c_mtx.GetOffsetFromMultiIndex(i, j);

-                static_if<std::is_same<FloatA, float>::value>{}([&](auto) {
-                    p_c_thread[cindex] += CVT_FLOAT2ACCUM(p_a_thread[aindex]) *
-                                          CVT_FLOAT2ACCUM(p_b_thread[bindex]);
-                }).Else([&](auto) {
-                    static_if<std::is_same<FloatA, half>::value>{}([&](auto) {
-                        const half* s0_half = reinterpret_cast<const half*>(&p_a_thread[aindex]);
-                        const half* s1_half = reinterpret_cast<const half*>(&p_b_thread[bindex]);
-
-                        p_c_thread[cindex] +=
-                            CVT_FLOAT2ACCUM(s0_half[0]) * CVT_FLOAT2ACCUM(s1_half[0]) +
-                            CVT_FLOAT2ACCUM(s0_half[1]) * CVT_FLOAT2ACCUM(s1_half[1]) +
-                            CVT_FLOAT2ACCUM(s0_half[2]) * CVT_FLOAT2ACCUM(s1_half[2]) +
-                            CVT_FLOAT2ACCUM(s0_half[3]) * CVT_FLOAT2ACCUM(s1_half[3]);
-                    }).Else([&](auto) {
-                        const ushort* s0_ushort = reinterpret_cast<const ushort*>(&p_a_thread[aindex]);
-                        const ushort* s1_ushort = reinterpret_cast<const ushort*>(&p_b_thread[bindex]);
-
-                        p_c_thread[cindex] +=
-                            CVT_FLOAT2ACCUM(s0_ushort[0]) * CVT_FLOAT2ACCUM(s1_ushort[0]) +
-                            CVT_FLOAT2ACCUM(s0_ushort[1]) * CVT_FLOAT2ACCUM(s1_ushort[1]);
-                    });
-                });
+                //static_if<std::is_same<FloatA, float>::value>{}([&](auto) {
+                //    p_c_thread[cindex] +=
+                //        CVT_FLOAT2ACCUM(p_a_thread[aindex]) * CVT_FLOAT2ACCUM(p_b_thread[bindex]);
+                //}).Else([&](auto) {
+                static_if<std::is_same<FloatA, ck::vector_type<half, 4>::MemoryType>::value>{}(
+                    [&](auto) {
+                        // If src/dst matrix datatype is bfloat16/float16 (vector size 2/4
+                        // respectively)
+                        float acc = 0.0;
+
+                        for(index_t v = 0; v < 4; ++v)
+                        {
+                            acc += CVT_FLOAT2ACCUM(p_a_thread[aindex * 4 + v]) *
+                                   CVT_FLOAT2ACCUM(p_b_thread[bindex * 4 + v]);
+                        }
+
+                        p_c_thread[cindex] += acc;
+                    }).Else([&](auto) {
+                        // If src/dst matrix datatype is bfloat16/float16 (vector size 2/4
+                        // respectively)
+                        float acc = 0.0;
+
+                        for(index_t v = 0; v < 2; ++v)
+                        {
+                            acc += CVT_FLOAT2ACCUM(p_a_thread[aindex * 2 + v]) *
+                                   CVT_FLOAT2ACCUM(p_b_thread[bindex * 2 + v]);
+                        }
+
+                        p_c_thread[cindex] += acc;
+                    });
+                // });
             }
         }
     }
 }
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp

@@ -112,7 +112,6 @@ __device__ void threadwise_generic_tensor_slice_copy_v1(
     static_if<std::is_same<vector_src_t, vector_dest_t>::value>{}([&](auto) {
         *reinterpret_cast<vector_dest_t*>(&p_dst[dst_index]) =
             *reinterpret_cast<const vector_src_t*>(&p_src[src_index]);
-        //printf("%f ", static_cast<float>(p_dst[dst_index]));
     }).Else([&](auto) {
         for(unsigned int data_idx = 0; data_idx < DataPerAccess; ++data_idx)
         {
driver/src/driver.cpp

@@ -138,11 +138,21 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw,
                             wi < in_nchw.mDesc.GetLengths()[3])
                         {
                             v += double(in_nchw(n, c, hi, wi)) * double(wei_kcyx(k, c, y, x));
+                            if(n == 0 && k == 0 && ho == 0 && wo == 0)
+                            {
+                                //std::cout << "cpu " << c << "," << hi << "," << wi << " * " <<
+                                //    << c << "," << y << "," << x << " = "
+                                //    << in_nchw(n,c,hi,wi) << " * " << wei_kcyx(k, c, y, x) << std::endl;
+                                // printf(" cpu %d,%d,%d * %d,%d,%d = %f * %f\n",
+                                //    c, hi, wi, c, y, x, double(in_nchw(n,c,hi,wi)), double(wei_kcyx(k, c, y, x)));
+                            }
                         }
                     }
                 }
             }
             out_nkhw(n, k, ho, wo) = v;
+            if(n == 0 && k == 0 && ho == 0 && wo == 0)
+                printf("cpu %d,%d,%d,%d = %f", n, k, ho, wo, v);
         };

         auto f_par = make_ParallelTensorFunctor(f,

@@ -787,9 +797,8 @@ int main(int argc, char* argv[])
     constexpr index_t HPad = 0;
     constexpr index_t WPad = 0;
-#elif 1
+#elif 0
     // 1x1 filter, 7x7 image
-    // cudnn@V100 49%, ck@V100 50%, ck@P100 61%, ck@VII 52%
     constexpr index_t N  = 32;
     constexpr index_t C  = 128;
     constexpr index_t HI = 28;

@@ -803,6 +812,20 @@ int main(int argc, char* argv[])
     constexpr index_t HPad = 0;
     constexpr index_t WPad = 0;
+#elif 1
+    constexpr index_t N  = 8;
+    constexpr index_t C  = 64;
+    constexpr index_t HI = 4;
+    constexpr index_t WI = 4;
+    constexpr index_t K  = 64;
+    constexpr index_t Y  = 1;
+    constexpr index_t X  = 1;
+
+    using ConvStrides   = Sequence<1, 1>;
+    using ConvDilations = Sequence<1, 1>;
+
+    constexpr index_t HPad = 0;
+    constexpr index_t WPad = 0;
 #endif

     auto lower_pads = Sequence<HPad, WPad>{};

@@ -897,7 +920,7 @@ int main(int argc, char* argv[])
     if(do_verification)
     {
-#if 1
+#if 0
         if(Y == 3 && X == 3 && ConvStrides{}[0] == 1 && ConvStrides{}[1] == 1 &&
            ConvDilations{}[0] == 1 && ConvDilations{}[1] == 1)
         {

@@ -915,6 +938,7 @@ int main(int argc, char* argv[])
                                    upper_pads);
         }

         check_error(out_nkhw_host, out_nkhw_device);
+        printf("gpu value %f", double(out_nkhw_device.mData[0]));
 #if 0
         LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;