Commit 8edbc659 authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent 04c5527d
...@@ -577,10 +577,10 @@ int main(int argc, char* argv[]) ...@@ -577,10 +577,10 @@ int main(int argc, char* argv[])
ostream_ConstantTensorDescriptor(wei_kcsr_desc, std::cout << "wei_kcsr_desc: "); ostream_ConstantTensorDescriptor(wei_kcsr_desc, std::cout << "wei_kcsr_desc: ");
ostream_ConstantTensorDescriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: "); ostream_ConstantTensorDescriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: ");
Tensor<float> in_nchw(make_TensorDescriptor(in_nchw_desc)); Tensor<half> in_nchw(make_TensorDescriptor(in_nchw_desc));
Tensor<float> wei_kcsr(make_TensorDescriptor(wei_kcsr_desc)); Tensor<half> wei_kcsr(make_TensorDescriptor(wei_kcsr_desc));
Tensor<float> out_nkhw_host(make_TensorDescriptor(out_nkhw_desc)); Tensor<half> out_nkhw_host(make_TensorDescriptor(out_nkhw_desc));
Tensor<float> out_nkhw_device(make_TensorDescriptor(out_nkhw_desc)); Tensor<half> out_nkhw_device(make_TensorDescriptor(out_nkhw_desc));
std::size_t num_thread = std::thread::hardware_concurrency(); std::size_t num_thread = std::thread::hardware_concurrency();
...@@ -633,7 +633,7 @@ int main(int argc, char* argv[]) ...@@ -633,7 +633,7 @@ int main(int argc, char* argv[])
if(do_verification) if(do_verification)
{ {
#if 1 #if 0
if(Y == 3 && X == 3) if(Y == 3 && X == 3)
{ {
host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads); host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads);
......
...@@ -10,8 +10,7 @@ __device__ void threadwise_6d_tensor_copy(SrcDesc, ...@@ -10,8 +10,7 @@ __device__ void threadwise_6d_tensor_copy(SrcDesc,
SrcOpLengths, SrcOpLengths,
Number<DataPerRead>) Number<DataPerRead>)
{ {
using Float2 = float2; using vector_t = typename vector_type<Float, DataPerRead>::type;
using Float4 = float4;
static_assert(SrcDesc{}.GetDimension() == 6 && DstDesc{}.GetDimension() == 6 && static_assert(SrcDesc{}.GetDimension() == 6 && DstDesc{}.GetDimension() == 6 &&
SrcOpLengths::nDim == 6, SrcOpLengths::nDim == 6,
...@@ -62,24 +61,8 @@ __device__ void threadwise_6d_tensor_copy(SrcDesc, ...@@ -62,24 +61,8 @@ __device__ void threadwise_6d_tensor_copy(SrcDesc,
const unsigned dst_index = dst_desc.Get1dIndex( const unsigned dst_index = dst_desc.Get1dIndex(
did0, did1, did2, did3, did4, iloop_d5 * DataPerRead); did0, did1, did2, did3, did4, iloop_d5 * DataPerRead);
if(DataPerRead == 1) *(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
{ *(reinterpret_cast<const vector_t*>(p_src + src_index));
p_dst[dst_index] = p_src[src_index];
}
else if(DataPerRead == 2)
{
*(reinterpret_cast<Float2*>(p_dst + dst_index)) =
*(reinterpret_cast<const Float2*>(p_src + src_index));
}
else if(DataPerRead == 4)
{
*(reinterpret_cast<Float4*>(p_dst + dst_index)) =
*(reinterpret_cast<const Float4*>(p_src + src_index));
}
else
{
assert(false);
}
} }
} }
} }
...@@ -97,8 +80,7 @@ __device__ void threadwise_8d_tensor_copy(SrcDesc, ...@@ -97,8 +80,7 @@ __device__ void threadwise_8d_tensor_copy(SrcDesc,
SrcOpLengths, SrcOpLengths,
Number<DataPerRead>) Number<DataPerRead>)
{ {
using Float2 = float2; using vector_t = typename vector_type<Float, DataPerRead>::type;
using Float4 = float4;
static_assert(SrcDesc{}.GetDimension() == 8 && DstDesc{}.GetDimension() == 8 && static_assert(SrcDesc{}.GetDimension() == 8 && DstDesc{}.GetDimension() == 8 &&
SrcOpLengths::nDim == 8, SrcOpLengths::nDim == 8,
...@@ -169,24 +151,8 @@ __device__ void threadwise_8d_tensor_copy(SrcDesc, ...@@ -169,24 +151,8 @@ __device__ void threadwise_8d_tensor_copy(SrcDesc,
did6, did6,
iloop_d7 * DataPerRead); iloop_d7 * DataPerRead);
if(DataPerRead == 1) *(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
{ *(reinterpret_cast<const vector_t*>(p_src + src_index));
p_dst[dst_index] = p_src[src_index];
}
else if(DataPerRead == 2)
{
*(reinterpret_cast<Float2*>(p_dst + dst_index)) =
*(reinterpret_cast<const Float2*>(p_src + src_index));
}
else if(DataPerRead == 4)
{
*(reinterpret_cast<Float4*>(p_dst + dst_index)) =
*(reinterpret_cast<const Float4*>(p_src + src_index));
}
else
{
assert(false);
}
} }
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment