gaoqiong / composable_kernel · Commits

Commit 06c9f9fe
Authored Oct 14, 2018 by Chao Liu

    initial build

Parent fc98757a
Showing 4 changed files with 121 additions and 53 deletions (+121 −53)
CMakeLists.txt            +1   −0
src/CMakeLists.txt        +1   −1
src/include/tensor.hpp    +113 −48
src/tensor.cpp            +6   −4
CMakeLists.txt

@@ -3,6 +3,7 @@ project(convolution LANGUAGES CXX CUDA)
 #c++
 message("CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}")
+add_compile_options(-std=c++14)

 #boost
 find_package(Boost REQUIRED)
src/CMakeLists.txt

@@ -12,7 +12,7 @@ target_link_libraries(convolution boost_python3)
 # cuda
 target_link_libraries(convolution nvToolsExt)

-target_compile_features(convolution PUBLIC cxx_std_11)
+target_compile_features(convolution PUBLIC)
 set_target_properties(convolution PROPERTIES POSITION_INDEPENDENT_CODE ON)
 set_target_properties(convolution PROPERTIES CUDA_SEPARABLE_COMPILATION OFF)
src/include/tensor.hpp

 #include <thread>
 #include <vector>
 #include <numeric>
+#include <utility>
+#include "cuda_runtime.h"
+#include "helper_cuda.h"

 typedef enum
 {
@@ -34,17 +37,21 @@ struct TensorDescriptor
         this->CalculateStrides();
     }

     template <class Range1, class Range2>
     TensorDescriptor(DataType_t t, const Range1& lens, const Range2& strides)
         : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end()), mDataType(t)
-    {}
+    {
+    }

     std::size_t GetDimension() const;
     std::size_t GetElementSize() const;
     std::size_t GetElementSpace() const;

-    template <class... Xs>
-    std::size_t GetIndex(Xs... xs) const
+    const std::vector<std::size_t>& GetLengths() const;
+    const std::vector<std::size_t>& GetStrides() const;
+
+    template <class... Xs>
+    std::size_t Get1dIndex(Xs... xs) const
     {
         assert(sizeof...(Xs) == this->GetDimension());
         std::initializer_list<std::size_t> is{xs...};
@@ -81,7 +88,49 @@ struct Tensor
     template <class G>
     void GenerateTensorValue(G g)
     {
-        parallel_for([&](Xs... xs) { mData(mDesc.GetIndex(xs...)) = g(xs...); }, mDesc.mLens);
+        // ParallelTensorFunctor([&](Xs... xs) { mData(mDesc.Get1dIndex(xs...)) = g(xs...); },
+        // mDesc.mLens)();
+        switch(mDesc.GetDimension())
+        {
+        case 1:
+        {
+            ParallelTensorFunctor([&](auto i) { mData(mDesc.Get1dIndex(i)) = g(i); },
+                                  mDesc.GetLengths()[0])();
+            break;
+        }
+        case 2:
+        {
+            ParallelTensorFunctor(
+                [&](auto i0, auto i1) { mData(mDesc.Get1dIndex(i0, i1)) = g(i0, i1); },
+                mDesc.GetLengths()[0],
+                mDesc.GetLengths()[1])();
+            break;
+        }
+        case 3:
+        {
+            ParallelTensorFunctor(
+                [&](auto i0, auto i1, auto i2) {
+                    mData(mDesc.Get1dIndex(i0, i1, i2)) = g(i0, i1, i2);
+                },
+                mDesc.GetLengths()[0],
+                mDesc.GetLengths()[1],
+                mDesc.GetLengths()[2])();
+            break;
+        }
+        case 4:
+        {
+            ParallelTensorFunctor(
+                [&](auto i0, auto i1, auto i2, auto i3) {
+                    mData(mDesc.Get1dIndex(i0, i1, i2, i3)) = g(i0, i1, i2, i3);
+                },
+                mDesc.GetLengths()[0],
+                mDesc.GetLengths()[1],
+                mDesc.GetLengths()[3],
+                mDesc.GetLengths()[4])();
+            break;
+        }
+        default: throw std::runtime_error("unspported dimension");
+        }
     }

     T& operator[](std::size_t i) { return mData.at(i); }
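The switch on GetDimension() above exists because the generator g has to be called with one index argument per dimension, so each runtime rank is dispatched to a lambda of matching arity. For orientation only, here is a self-contained sketch of the same dispatch idea with a hypothetical name (visit_all_indices); it is not code from this commit:

#include <cstddef>
#include <functional>
#include <iostream>
#include <stdexcept>
#include <vector>

// Dispatch a runtime rank onto fixed-arity loops, mirroring the intent of
// GenerateTensorValue: the callback receives one index per dimension.
void visit_all_indices(const std::vector<std::size_t>& lens,
                       const std::function<void(const std::vector<std::size_t>&)>& visit)
{
    switch(lens.size())
    {
    case 1:
        for(std::size_t i0 = 0; i0 < lens[0]; ++i0)
            visit({i0});
        break;
    case 2:
        for(std::size_t i0 = 0; i0 < lens[0]; ++i0)
            for(std::size_t i1 = 0; i1 < lens[1]; ++i1)
                visit({i0, i1});
        break;
    default: throw std::runtime_error("unsupported dimension");
    }
}

int main()
{
    // Fill a 2x3 row-major buffer from its multi-dimensional indices.
    std::vector<int> data(6);
    visit_all_indices({2, 3}, [&](const std::vector<std::size_t>& idx) {
        data[idx[0] * 3 + idx[1]] = static_cast<int>(10 * idx[0] + idx[1]);
    });
    for(int v : data)
        std::cout << v << ' ';
    std::cout << '\n'; // prints: 0 1 2 10 11 12
}

The commit's version hands the per-rank lambda to ParallelTensorFunctor instead of looping inline.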
@@ -103,42 +152,44 @@ struct Tensor

 struct GpuMem
 {
     GpuMem() = delete;
-    GpuMem(std::size_t sz, std::size_t data_sz) : mSz(sz), mDataSz(data_sz)
+    GpuMem(std::size_t size, std::size_t data_size) : mSize(size), mDataSize(data_size)
     {
-        cudaMalloc(statci_cast<void**>(&GpuBuf), mDataSize * mSz);
+        cudaMalloc(static_cast<void**>(&mGpuBuf), mDataSize * mSize);
     }

     int ToGpu(void* p)
     {
-        return static_cast<int>(cudaMemcpy(mGpuBuf, p, mDataSz * mSz, cudaMemCpyHostToDevice));
+        return static_cast<int>(cudaMemcpy(mGpuBuf, p, mDataSize * mSize, cudaMemcpyHostToDevice));
     }

-    int FromGpu(void* p) { return static_cast<int>(cuadMemCpy(p, mGpuBuf, mDataSz * mSz)); }
+    int FromGpu(void* p)
+    {
+        return static_cast<int>(cudaMemcpy(p, mGpuBuf, mDataSize * mSize, cudaMemcpyDeviceToHost));
+    }

     ~GpuMem() { cudaFree(mGpuBuf); }

     void* mGpuBuf;
-    std::size_t mSz;
-    std::size_t mDataSz;
+    std::size_t mSize;
+    std::size_t mDataSize;
 };

-void dummy()
-{
-    auto f1 = [](int n, int c, int h, int w) { do_f1(n, c, h, w); };
-    auto f2 = [](int n, int c, int h, int w) { do_f2(n, c, h, w); };
-
-    auto par_f1 = generate_ParallelTensorFunctor(f1, 3, 3, 3, 3, 3);
-    auto par_f2 = generate_ParallelTensorFunctor(f2, 4, 4, 4);
-
-    auto r1 = par_f1();
-    auto r2 = par_f2();
-}
-
-template <class F, class... Xs>
-auto generate_parallel_tensor_functor(F f, Xs... xs)
-{
-    return ParallelTensorFunctor(f, xs...);
-}
+struct joinable_thread : std::thread
+{
+    template <class... Xs>
+    joinable_thread(Xs&&... xs) : std::thread(std::forward<Xs>(xs)...)
+    {
+    }
+
+    joinable_thread(joinable_thread&&) = default;
+    joinable_thread& operator=(joinable_thread&&) = default;
+
+    ~joinable_thread()
+    {
+        if(this->joinable())
+            this->join();
+    }
+};

 template <class F, class... Xs>
 struct ParallelTensorFunctor
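The new joinable_thread type replaces the dummy()/generate_parallel_tensor_functor placeholders: it is a std::thread that joins itself in its destructor, so a worker launched by the parallel path below cannot outlive its scope. A minimal standalone usage sketch (illustrative, not part of the diff):

#include <iostream>
#include <thread>
#include <utility>

// Same idea as the commit's joinable_thread: RAII joining on scope exit.
struct joinable_thread : std::thread
{
    template <class... Xs>
    joinable_thread(Xs&&... xs) : std::thread(std::forward<Xs>(xs)...)
    {
    }

    joinable_thread(joinable_thread&&) = default;
    joinable_thread& operator=(joinable_thread&&) = default;

    ~joinable_thread()
    {
        if(this->joinable())
            this->join();
    }
};

int main()
{
    {
        joinable_thread t([] { std::cout << "worker done\n"; });
        // no explicit t.join(): the destructor joins when t leaves scope
    }
    std::cout << "main done\n";
}

Defaulting the move operations is what allows a later hunk to write threads[it] = joinable_thread(f);.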
@@ -150,7 +201,7 @@ struct ParallelTensorFunctor
     };

     F mF;
-    constexpr std::size_t DIM = sizeof...(Xs);
+    static constexpr std::size_t NDIM = sizeof...(Xs);
     std::array<std::size_t, NDIM> mLens;
     std::array<std::size_t, NDIM> mStrides;
     std::size_t mN1d;
@@ -165,16 +216,29 @@ struct ParallelTensorFunctor
         mN1d = mStrides[0] * mLens[0];
     }

+    std::array<std::size_t, NDIM> GetNdIndices(std::size_t i) const
+    {
+        std::array<std::size_t, NDIM> indices;
+
+        for(int idim = 0; idim < NDIM; ++idim)
+        {
+            indices[idim] = i / mStrides[idim];
+            i -= indices[idim] * mStrides[idim];
+        }
+
+        return indices;
+    }
+
     void operator()(std::integral_constant<ParallelMethod_t, ParallelMethod_t::Serial>)
     {
         for(std::size_t i = 0; i < mN1d; ++i)
         {
-            call_f_unpack_indices(mF, GetNdIndices(i));
+            call_f_unpack_args(mF, GetNdIndices(i));
         }
     }

     void operator()(std::integral_constant<ParallelMethod_t, ParallelMethod_t::Parallel>,
-                    std::size_t::num_thread)
+                    std::size_t num_thread)
     {
         std::size_t work_per_thread = (mN1d + num_thread - 1) / num_thread;
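GetNdIndices, added above, maps a flat 1-D index back to per-dimension indices by dividing by each dimension's stride and carrying the remainder forward. A small self-contained check of the same arithmetic, assuming packed row-major strides (not part of the commit):

#include <array>
#include <cstddef>
#include <iostream>

// Same divide-by-stride decomposition as GetNdIndices, fixed to rank 3.
std::array<std::size_t, 3> nd_indices(std::size_t i, const std::array<std::size_t, 3>& strides)
{
    std::array<std::size_t, 3> indices{};
    for(std::size_t d = 0; d < 3; ++d)
    {
        indices[d] = i / strides[d];
        i -= indices[d] * strides[d];
    }
    return indices;
}

int main()
{
    // Row-major strides for lengths {2, 3, 4} are {12, 4, 1}.
    const std::array<std::size_t, 3> strides{12, 4, 1};
    const auto idx = nd_indices(17, strides); // 17 = 1*12 + 1*4 + 1*1
    std::cout << idx[0] << ' ' << idx[1] << ' ' << idx[2] << '\n'; // prints: 1 1 1
}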
@@ -183,42 +247,43 @@ struct ParallelTensorFunctor
         for(std::size_t it = 0; it < num_thread; ++it)
         {
             std::size_t iw_begin = it * work_per_thread;
             std::size_t iw_end   = std::min(((it + 1) * work_per_thread, mN1d));

             auto f = [=] {
                 for(std::size_t iw = iw_begin; iw < iw_end; ++iw)
-                    call_f_unpack_indices(mF, GetNdIndices(iw);
+                {
+                    call_f_unpack_args(mF, GetNdIndices(iw));
+                }
             };

             threads[it] = joinable_thread(f);
         }
     }
 };
-struct joinable_thread : std::thread
-{
-    template <class... Xs>
-    joinable_thread(Xs&&... xs) : std::thread(std::forward<Xs>(xs)...)
-    {
-    }
-
-    ~joinable_thread()
-    {
-        if(this->joinable())
-            this->join;
-    }
-};
-
-template <class F, class T>
-auto call_f_unpack_indices(F f, T indices)
-{
-    constexpr std::size_t N = std::tuple_size<T>::value;
-    using NSeq = std::make_integer_sequence<std::size_t, N>;
-    return call_f_unpack_indices_impl(f, indices, NSeq{});
-}
-
-template <class F, class T, class... Is>
-auto call_f_unpack_indices_impl(F f, T indices, std::integer_sequence<std::size_t, Is...>)
-{
-    return f(std::get<Is>(indices)...);
-}
+template <class F, class T>
+auto call_f_unpack_args(F f, T args)
+{
+    static constexpr std::size_t N = std::tuple_size<T>::value;
+    return call_f_unpack_args_impl(f, args, std::make_index_sequence<N>{});
+}
+
+template <class F, class T, class... Is>
+auto call_f_unpack_args_impl(F f, T args, std::integer_sequence<Is...>)
+{
+    return f(std::get<Is>(args)...);
+}
+
+template <class F, class T, class... Is>
+auto construct_f_unpack_args_impl(T args, std::integer_sequence<Is...>)
+{
+    return F(std::get<Is>(args)...);
+}
+
+template <class F, class T>
+auto construct_f_unpack_args(F, T args)
+{
+    static constexpr std::size_t N = std::tuple_size<T>::value;
+    return construct_f_unpack_args_impl<F>(args, std::make_index_sequence<N>{});
+}
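The call_f_unpack_args helpers expand a tuple-like object of indices into individual arguments of f through an index sequence; construct_f_unpack_args does the same but invokes a constructor of F. A compilable sketch of the technique using a hypothetical name (call_unpacked), not the commit's exact code:

#include <array>
#include <cstddef>
#include <iostream>
#include <tuple>
#include <utility>

// Expand the elements of a std::array (or std::tuple) into separate call
// arguments, the same trick call_f_unpack_args uses.
template <class F, class T, std::size_t... Is>
auto call_unpacked_impl(F f, const T& args, std::index_sequence<Is...>)
{
    return f(std::get<Is>(args)...);
}

template <class F, class T>
auto call_unpacked(F f, const T& args)
{
    constexpr std::size_t N = std::tuple_size<T>::value;
    return call_unpacked_impl(f, args, std::make_index_sequence<N>{});
}

int main()
{
    const std::array<std::size_t, 3> idx{1, 2, 3};
    auto sum =
        call_unpacked([](std::size_t a, std::size_t b, std::size_t c) { return a + b + c; }, idx);
    std::cout << sum << '\n'; // prints: 6
}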
src/tensor.cpp

@@ -3,8 +3,6 @@
 #include "tensor.hpp"

-TensorDescriptor::TensorDescriptor() {}
-
 TensorDescriptor::TensorDescriptor(DataType_t t, std::initializer_list<std::size_t> lens)
     : mLens(lens), mDataType(t)
 {

@@ -22,7 +20,7 @@ void TensorDescriptor::CalculateStrides()
 {
     mStrides.clear();
     mStrides.resize(mLens.size(), 0);
-    if(strides.empty())
+    if(mStrides.empty())
         return;

     mStrides.back() = 1;

@@ -41,6 +39,10 @@ std::size_t TensorDescriptor::GetElementSize() const
 std::size_t TensorDescriptor::GetElementSpace() const
 {
-    auto ls = mLens | boost::adaptor::transformed([](auto v) { return v - 1; });
+    auto ls = mLens | boost::adaptors::transformed([](std::size_t v) { return v - 1; });
     return std::inner_product(ls.begin(), ls.end(), mStrides.begin(), std::size_t{0}) + 1;
 }
+
+const std::vector<std::size_t>& TensorDescriptor::GetLengths() const { return mLens; }
+
+const std::vector<std::size_t>& TensorDescriptor::GetStrides() const { return mStrides; }
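GetElementSpace() returns the offset of the last addressable element plus one, i.e. the sum over all dimensions of (length - 1) * stride, plus 1. A standalone check of that formula with a hypothetical helper (not part of the commit):

#include <cstddef>
#include <iostream>
#include <vector>

// Same arithmetic as GetElementSpace(): offset of the last element, plus one.
std::size_t element_space(const std::vector<std::size_t>& lens,
                          const std::vector<std::size_t>& strides)
{
    std::size_t space = 1;
    for(std::size_t d = 0; d < lens.size(); ++d)
        space += (lens[d] - 1) * strides[d];
    return space;
}

int main()
{
    // Packed row-major 2x3 tensor, strides {3, 1}: (2-1)*3 + (3-1)*1 + 1 = 6.
    std::cout << element_space({2, 3}, {3, 1}) << '\n'; // prints: 6
    // Padded rows, stride {4, 1}: (2-1)*4 + (3-1)*1 + 1 = 7.
    std::cout << element_space({2, 3}, {4, 1}) << '\n'; // prints: 7
}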