"tests/python/common/test_batch-graph.py" did not exist on "605b51857ea2380da76816cdf56f64e69df30dba"
Commit 1de47514 authored by Davis King's avatar Davis King
Browse files

Make input_layer() work with networks that contain repeat layers.

Do this by just making all layers have a .input_layer() method, which in
that context can be implemented in a simple manner.
parent ded68b9a
...@@ -652,7 +652,10 @@ namespace dlib ...@@ -652,7 +652,10 @@ namespace dlib
// Not much here because in this case T is one of the input layer types // Not much here because in this case T is one of the input layer types
// that doesn't have anything in it. // that doesn't have anything in it.
typedef T layer_details_type; typedef T layer_details_type;
typedef T input_layer_type;
const layer_details_type& layer_details() const { return l; } const layer_details_type& layer_details() const { return l; }
const input_layer_type& input_layer() const { return l; }
input_layer_type& input_layer() { return l; }
unsigned int sample_expansion_factor() const { return _sample_expansion_factor; } unsigned int sample_expansion_factor() const { return _sample_expansion_factor; }
private: private:
T& l; T& l;
...@@ -671,6 +674,7 @@ namespace dlib ...@@ -671,6 +674,7 @@ namespace dlib
const static size_t num_computational_layers = T::num_computational_layers; const static size_t num_computational_layers = T::num_computational_layers;
const static size_t num_layers = T::num_layers; const static size_t num_layers = T::num_layers;
typedef typename T::layer_details_type layer_details_type; typedef typename T::layer_details_type layer_details_type;
typedef typename T::input_layer_type input_layer_type;
subnet_wrapper(T& l_, unsigned int = 0) : l(l_),subnetwork(l.subnet(), l.sample_expansion_factor()) {} subnet_wrapper(T& l_, unsigned int = 0) : l(l_),subnetwork(l.subnet(), l.sample_expansion_factor()) {}
...@@ -683,6 +687,9 @@ namespace dlib ...@@ -683,6 +687,9 @@ namespace dlib
subnet_wrapper<typename T::subnet_type,false>& subnet() { return subnetwork; } subnet_wrapper<typename T::subnet_type,false>& subnet() { return subnetwork; }
unsigned int sample_expansion_factor() const { return l.sample_expansion_factor(); } unsigned int sample_expansion_factor() const { return l.sample_expansion_factor(); }
const input_layer_type& input_layer() const { return l.input_layer(); }
input_layer_type& input_layer() { return l.input_layer(); }
private: private:
T& l; T& l;
subnet_wrapper<typename T::subnet_type,false> subnetwork; subnet_wrapper<typename T::subnet_type,false> subnetwork;
...@@ -700,6 +707,7 @@ namespace dlib ...@@ -700,6 +707,7 @@ namespace dlib
const static size_t num_computational_layers = T::num_computational_layers; const static size_t num_computational_layers = T::num_computational_layers;
const static size_t num_layers = T::num_layers; const static size_t num_layers = T::num_layers;
typedef typename T::layer_details_type layer_details_type; typedef typename T::layer_details_type layer_details_type;
typedef typename T::input_layer_type input_layer_type;
subnet_wrapper(T& l_, unsigned int = 0) : l(l_),subnetwork(l.subnet(), l.sample_expansion_factor()) {} subnet_wrapper(T& l_, unsigned int = 0) : l(l_),subnetwork(l.subnet(), l.sample_expansion_factor()) {}
...@@ -712,6 +720,9 @@ namespace dlib ...@@ -712,6 +720,9 @@ namespace dlib
subnet_wrapper<typename T::subnet_type,false>& subnet() { return subnetwork; } subnet_wrapper<typename T::subnet_type,false>& subnet() { return subnetwork; }
unsigned int sample_expansion_factor() const { return l.sample_expansion_factor(); } unsigned int sample_expansion_factor() const { return l.sample_expansion_factor(); }
const input_layer_type& input_layer() const { return l.input_layer(); }
input_layer_type& input_layer() { return l.input_layer(); }
private: private:
T& l; T& l;
subnet_wrapper<typename T::subnet_type,false> subnetwork; subnet_wrapper<typename T::subnet_type,false> subnetwork;
...@@ -738,6 +749,7 @@ namespace dlib ...@@ -738,6 +749,7 @@ namespace dlib
public: public:
typedef LAYER_DETAILS layer_details_type; typedef LAYER_DETAILS layer_details_type;
typedef SUBNET subnet_type; typedef SUBNET subnet_type;
typedef typename subnet_type::input_layer_type input_layer_type;
typedef typename subnet_type::input_type input_type; typedef typename subnet_type::input_type input_type;
const static size_t num_layers = subnet_type::num_layers + 1; const static size_t num_layers = subnet_type::num_layers + 1;
const static size_t num_computational_layers = subnet_type::num_computational_layers + 1; const static size_t num_computational_layers = subnet_type::num_computational_layers + 1;
...@@ -1037,6 +1049,9 @@ namespace dlib ...@@ -1037,6 +1049,9 @@ namespace dlib
const subnet_type& subnet() const { return *subnetwork; } const subnet_type& subnet() const { return *subnetwork; }
subnet_type& subnet() { return *subnetwork; } subnet_type& subnet() { return *subnetwork; }
const input_layer_type& input_layer() const { return subnet().input_layer(); }
input_layer_type& input_layer() { return subnet().input_layer(); }
const layer_details_type& layer_details() const { return details; } const layer_details_type& layer_details() const { return details; }
layer_details_type& layer_details() { return details; } layer_details_type& layer_details() { return details; }
...@@ -1164,6 +1179,7 @@ namespace dlib ...@@ -1164,6 +1179,7 @@ namespace dlib
public: public:
typedef LAYER_DETAILS layer_details_type; typedef LAYER_DETAILS layer_details_type;
typedef INPUT_LAYER subnet_type; typedef INPUT_LAYER subnet_type;
typedef INPUT_LAYER input_layer_type;
typedef typename INPUT_LAYER::input_type input_type; typedef typename INPUT_LAYER::input_type input_type;
const static size_t num_layers = 2; const static size_t num_layers = 2;
const static size_t num_computational_layers = 1; const static size_t num_computational_layers = 1;
...@@ -1198,7 +1214,7 @@ namespace dlib ...@@ -1198,7 +1214,7 @@ namespace dlib
add_layer( add_layer(
const add_layer<T,U,E>& item const add_layer<T,U,E>& item
): ):
input_layer(item.subnet()), input_layer_(item.subnet()),
details(item.layer_details()), details(item.layer_details()),
this_layer_setup_called(item.this_layer_setup_called), this_layer_setup_called(item.this_layer_setup_called),
gradient_input_is_stale(item.gradient_input_is_stale), gradient_input_is_stale(item.gradient_input_is_stale),
...@@ -1223,7 +1239,7 @@ namespace dlib ...@@ -1223,7 +1239,7 @@ namespace dlib
add_layer( add_layer(
const INPUT_LAYER& il const INPUT_LAYER& il
) : ) :
input_layer(il), input_layer_(il),
this_layer_setup_called(false), this_layer_setup_called(false),
gradient_input_is_stale(true), gradient_input_is_stale(true),
get_output_and_gradient_input_disabled(false), get_output_and_gradient_input_disabled(false),
...@@ -1245,7 +1261,7 @@ namespace dlib ...@@ -1245,7 +1261,7 @@ namespace dlib
INPUT_LAYER il INPUT_LAYER il
) : ) :
details(std::move(layer_det)), details(std::move(layer_det)),
input_layer(std::move(il)), input_layer_(std::move(il)),
this_layer_setup_called(false), this_layer_setup_called(false),
gradient_input_is_stale(true), gradient_input_is_stale(true),
get_output_and_gradient_input_disabled(false), get_output_and_gradient_input_disabled(false),
...@@ -1284,7 +1300,7 @@ namespace dlib ...@@ -1284,7 +1300,7 @@ namespace dlib
resizable_tensor& data resizable_tensor& data
) const ) const
{ {
input_layer.to_tensor(ibegin, iend, data); input_layer_.to_tensor(ibegin, iend, data);
// make sure the input layer's to_tensor() function is implemented properly. // make sure the input layer's to_tensor() function is implemented properly.
DLIB_CASSERT(data.num_samples() >= std::distance(ibegin,iend), DLIB_CASSERT(data.num_samples() >= std::distance(ibegin,iend),
"The input layer can't produce fewer output tensors than there are inputs."); "The input layer can't produce fewer output tensors than there are inputs.");
...@@ -1403,8 +1419,11 @@ namespace dlib ...@@ -1403,8 +1419,11 @@ namespace dlib
tensor& get_parameter_gradient ( tensor& get_parameter_gradient (
) { return params_grad; } ) { return params_grad; }
const subnet_type& subnet() const { return input_layer; } const subnet_type& subnet() const { return input_layer_; }
subnet_type& subnet() { return input_layer; } subnet_type& subnet() { return input_layer_; }
const subnet_type& input_layer() const { return input_layer_; }
subnet_type& input_layer() { return input_layer_; }
const layer_details_type& layer_details() const { return details; } const layer_details_type& layer_details() const { return details; }
layer_details_type& layer_details() { return details; } layer_details_type& layer_details() { return details; }
...@@ -1426,7 +1445,7 @@ namespace dlib ...@@ -1426,7 +1445,7 @@ namespace dlib
{ {
int version = 3; int version = 3;
serialize(version, out); serialize(version, out);
serialize(item.input_layer, out); serialize(item.input_layer_, out);
serialize(item.details, out); serialize(item.details, out);
serialize(item.this_layer_setup_called, out); serialize(item.this_layer_setup_called, out);
serialize(item.gradient_input_is_stale, out); serialize(item.gradient_input_is_stale, out);
...@@ -1443,7 +1462,7 @@ namespace dlib ...@@ -1443,7 +1462,7 @@ namespace dlib
deserialize(version, in); deserialize(version, in);
if (!(2 <= version && version <= 3)) if (!(2 <= version && version <= 3))
throw serialization_error("Unexpected version found while deserializing dlib::add_layer."); throw serialization_error("Unexpected version found while deserializing dlib::add_layer.");
deserialize(item.input_layer, in); deserialize(item.input_layer_, in);
deserialize(item.details, in); deserialize(item.details, in);
deserialize(item.this_layer_setup_called, in); deserialize(item.this_layer_setup_called, in);
deserialize(item.gradient_input_is_stale, in); deserialize(item.gradient_input_is_stale, in);
...@@ -1512,7 +1531,7 @@ namespace dlib ...@@ -1512,7 +1531,7 @@ namespace dlib
void swap(add_layer& item) void swap(add_layer& item)
{ {
std::swap(input_layer, item.input_layer); std::swap(input_layer_, item.input_layer_);
std::swap(details, item.details); std::swap(details, item.details);
std::swap(this_layer_setup_called, item.this_layer_setup_called); std::swap(this_layer_setup_called, item.this_layer_setup_called);
std::swap(gradient_input_is_stale, item.gradient_input_is_stale); std::swap(gradient_input_is_stale, item.gradient_input_is_stale);
...@@ -1523,7 +1542,7 @@ namespace dlib ...@@ -1523,7 +1542,7 @@ namespace dlib
std::swap(_sample_expansion_factor, item._sample_expansion_factor); std::swap(_sample_expansion_factor, item._sample_expansion_factor);
} }
subnet_type input_layer; subnet_type input_layer_;
LAYER_DETAILS details; LAYER_DETAILS details;
bool this_layer_setup_called; bool this_layer_setup_called;
bool gradient_input_is_stale; bool gradient_input_is_stale;
...@@ -1558,6 +1577,7 @@ namespace dlib ...@@ -1558,6 +1577,7 @@ namespace dlib
public: public:
typedef SUBNET subnet_type; typedef SUBNET subnet_type;
typedef typename subnet_type::input_type input_type; typedef typename subnet_type::input_type input_type;
typedef typename subnet_type::input_layer_type input_layer_type;
typedef int layer_details_type; // not really used anywhere, but required by subnet_wrapper. typedef int layer_details_type; // not really used anywhere, but required by subnet_wrapper.
const static size_t num_layers = subnet_type::num_layers + 1; const static size_t num_layers = subnet_type::num_layers + 1;
const static size_t num_computational_layers = subnet_type::num_computational_layers; const static size_t num_computational_layers = subnet_type::num_computational_layers;
...@@ -1652,6 +1672,9 @@ namespace dlib ...@@ -1652,6 +1672,9 @@ namespace dlib
const subnet_type& subnet() const { return subnetwork; } const subnet_type& subnet() const { return subnetwork; }
subnet_type& subnet() { return subnetwork; } subnet_type& subnet() { return subnetwork; }
const input_layer_type& input_layer() const { return subnet().input_layer(); }
input_layer_type& input_layer() { return subnet().input_layer(); }
unsigned int sample_expansion_factor() const { return subnet().sample_expansion_factor(); } unsigned int sample_expansion_factor() const { return subnet().sample_expansion_factor(); }
void clean() void clean()
...@@ -1760,6 +1783,7 @@ namespace dlib ...@@ -1760,6 +1783,7 @@ namespace dlib
public: public:
typedef SUBNET subnet_type; typedef SUBNET subnet_type;
typedef typename SUBNET::input_type input_type; typedef typename SUBNET::input_type input_type;
typedef typename subnet_type::input_layer_type input_layer_type;
typedef int layer_details_type; // not really used anywhere, but required by subnet_wrapper. typedef int layer_details_type; // not really used anywhere, but required by subnet_wrapper.
const static size_t comp_layers_in_each_group = (REPEATED_LAYER<SUBNET>::num_computational_layers-SUBNET::num_computational_layers); const static size_t comp_layers_in_each_group = (REPEATED_LAYER<SUBNET>::num_computational_layers-SUBNET::num_computational_layers);
const static size_t comp_layers_in_repeated_group = comp_layers_in_each_group*num; const static size_t comp_layers_in_repeated_group = comp_layers_in_each_group*num;
...@@ -1951,6 +1975,9 @@ namespace dlib ...@@ -1951,6 +1975,9 @@ namespace dlib
const subnet_type& subnet() const { return subnetwork; } const subnet_type& subnet() const { return subnetwork; }
subnet_type& subnet() { return subnetwork; } subnet_type& subnet() { return subnetwork; }
const input_layer_type& input_layer() const { return subnet().input_layer(); }
input_layer_type& input_layer() { return subnet().input_layer(); }
unsigned int sample_expansion_factor() const { return subnet().sample_expansion_factor(); } unsigned int sample_expansion_factor() const { return subnet().sample_expansion_factor(); }
void clean() void clean()
...@@ -2047,6 +2074,7 @@ namespace dlib ...@@ -2047,6 +2074,7 @@ namespace dlib
public: public:
typedef INPUT_LAYER subnet_type; typedef INPUT_LAYER subnet_type;
typedef typename subnet_type::input_type input_type; typedef typename subnet_type::input_type input_type;
typedef INPUT_LAYER input_layer_type;
typedef int layer_details_type; // not really used anywhere, but required by subnet_wrapper. typedef int layer_details_type; // not really used anywhere, but required by subnet_wrapper.
const static size_t num_computational_layers = 0; const static size_t num_computational_layers = 0;
const static size_t num_layers = 2; const static size_t num_layers = 2;
...@@ -2062,7 +2090,7 @@ namespace dlib ...@@ -2062,7 +2090,7 @@ namespace dlib
template <typename T, typename E> template <typename T, typename E>
add_tag_layer( add_tag_layer(
const add_tag_layer<ID,T,E>& item const add_tag_layer<ID,T,E>& item
) : input_layer(item.subnet()), ) : input_layer_(item.subnet()),
cached_output(item.cached_output), cached_output(item.cached_output),
cached_output_ptr(nullptr), cached_output_ptr(nullptr),
grad_final(item.grad_final), grad_final(item.grad_final),
...@@ -2074,7 +2102,7 @@ namespace dlib ...@@ -2074,7 +2102,7 @@ namespace dlib
add_tag_layer( add_tag_layer(
T ...args T ...args
) : ) :
input_layer(std::move(args)...), input_layer_(std::move(args)...),
cached_output_ptr(nullptr), cached_output_ptr(nullptr),
gradient_input_is_stale(true), gradient_input_is_stale(true),
_sample_expansion_factor(0) _sample_expansion_factor(0)
...@@ -2096,7 +2124,7 @@ namespace dlib ...@@ -2096,7 +2124,7 @@ namespace dlib
resizable_tensor& data resizable_tensor& data
) const ) const
{ {
input_layer.to_tensor(ibegin,iend,data); input_layer_.to_tensor(ibegin,iend,data);
// make sure the input layer's to_tensor() function is implemented properly. // make sure the input layer's to_tensor() function is implemented properly.
DLIB_CASSERT(data.num_samples() >= std::distance(ibegin,iend), DLIB_CASSERT(data.num_samples() >= std::distance(ibegin,iend),
...@@ -2116,7 +2144,7 @@ namespace dlib ...@@ -2116,7 +2144,7 @@ namespace dlib
forward_iterator iend forward_iterator iend
) )
{ {
input_layer.to_tensor(ibegin,iend,cached_output); input_layer_.to_tensor(ibegin,iend,cached_output);
cached_output_ptr = nullptr; cached_output_ptr = nullptr;
return get_output(); return get_output();
} }
...@@ -2184,8 +2212,11 @@ namespace dlib ...@@ -2184,8 +2212,11 @@ namespace dlib
update_parameters(make_sstack(solvers), learning_rate); update_parameters(make_sstack(solvers), learning_rate);
} }
const subnet_type& subnet() const { return input_layer; } const subnet_type& subnet() const { return input_layer_; }
subnet_type& subnet() { return input_layer; } subnet_type& subnet() { return input_layer_; }
const input_layer_type& input_layer() const { return input_layer_; }
input_layer_type& input_layer() { return input_layer_; }
void clean() void clean()
{ {
...@@ -2198,7 +2229,7 @@ namespace dlib ...@@ -2198,7 +2229,7 @@ namespace dlib
{ {
int version = 2; int version = 2;
serialize(version, out); serialize(version, out);
serialize(item.input_layer, out); serialize(item.input_layer_, out);
serialize(item.cached_output, out); serialize(item.cached_output, out);
serialize(item.grad_final, out); serialize(item.grad_final, out);
serialize(item.gradient_input_is_stale, out); serialize(item.gradient_input_is_stale, out);
...@@ -2211,7 +2242,7 @@ namespace dlib ...@@ -2211,7 +2242,7 @@ namespace dlib
deserialize(version, in); deserialize(version, in);
if (!(1 <= version && version <= 2)) if (!(1 <= version && version <= 2))
throw serialization_error("Unexpected version found while deserializing dlib::add_tag_layer."); throw serialization_error("Unexpected version found while deserializing dlib::add_tag_layer.");
deserialize(item.input_layer, in); deserialize(item.input_layer_, in);
deserialize(item.cached_output, in); deserialize(item.cached_output, in);
deserialize(item.grad_final, in); deserialize(item.grad_final, in);
deserialize(item.gradient_input_is_stale, in); deserialize(item.gradient_input_is_stale, in);
...@@ -2275,7 +2306,7 @@ namespace dlib ...@@ -2275,7 +2306,7 @@ namespace dlib
void swap(add_tag_layer& item) void swap(add_tag_layer& item)
{ {
std::swap(input_layer, item.input_layer); std::swap(input_layer_, item.input_layer_);
std::swap(cached_output, item.cached_output); std::swap(cached_output, item.cached_output);
std::swap(cached_output_ptr, item.cached_output_ptr); std::swap(cached_output_ptr, item.cached_output_ptr);
std::swap(grad_final, item.grad_final); std::swap(grad_final, item.grad_final);
...@@ -2283,7 +2314,7 @@ namespace dlib ...@@ -2283,7 +2314,7 @@ namespace dlib
std::swap(_sample_expansion_factor, item._sample_expansion_factor); std::swap(_sample_expansion_factor, item._sample_expansion_factor);
} }
subnet_type input_layer; subnet_type input_layer_;
resizable_tensor cached_output; resizable_tensor cached_output;
tensor* cached_output_ptr; tensor* cached_output_ptr;
resizable_tensor grad_final; resizable_tensor grad_final;
...@@ -2348,6 +2379,7 @@ namespace dlib ...@@ -2348,6 +2379,7 @@ namespace dlib
typedef LOSS_DETAILS loss_details_type; typedef LOSS_DETAILS loss_details_type;
typedef SUBNET subnet_type; typedef SUBNET subnet_type;
typedef typename subnet_type::input_type input_type; typedef typename subnet_type::input_type input_type;
typedef typename subnet_type::input_layer_type input_layer_type;
const static size_t num_layers = subnet_type::num_layers + 1; const static size_t num_layers = subnet_type::num_layers + 1;
// Note that the loss layer doesn't count as an additional computational layer. // Note that the loss layer doesn't count as an additional computational layer.
const static size_t num_computational_layers = subnet_type::num_computational_layers; const static size_t num_computational_layers = subnet_type::num_computational_layers;
...@@ -2628,6 +2660,10 @@ namespace dlib ...@@ -2628,6 +2660,10 @@ namespace dlib
const subnet_type& subnet() const { return subnetwork; } const subnet_type& subnet() const { return subnetwork; }
subnet_type& subnet() { return subnetwork; } subnet_type& subnet() { return subnetwork; }
const input_layer_type& input_layer() const { return subnet().input_layer(); }
input_layer_type& input_layer() { return subnet().input_layer(); }
const loss_details_type& loss_details() const { return loss; } const loss_details_type& loss_details() const { return loss; }
loss_details_type& loss_details() { return loss; } loss_details_type& loss_details() { return loss; }
...@@ -2888,43 +2924,20 @@ namespace dlib ...@@ -2888,43 +2924,20 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template <typename net_type>
namespace dimpl typename net_type::input_layer_type& input_layer (
net_type& net
)
{ {
template <typename T> return net.input_layer();
T& get_input_details (
T& net
)
{
return net;
}
template <typename T, bool is_first, typename enabled>
auto get_input_details (
dimpl::subnet_wrapper<T,is_first,enabled>& net
) -> decltype(net.layer_details())&
{
return net.layer_details();
}
template <typename T, bool is_first, typename enabled>
auto get_input_details (
const dimpl::subnet_wrapper<T,is_first,enabled>& net
) -> decltype(net.layer_details())&
{
return net.layer_details();
}
} }
template <typename net_type> template <typename net_type>
auto input_layer ( const typename net_type::input_layer_type& input_layer (
net_type& net const net_type& net
) -> decltype(dimpl::get_input_details(layer<net_type::num_layers-1>(net)))& )
{ {
// Calling input_layer() on a subnet_wrapper is a little funny since the behavior of return net.input_layer();
// .subnet() returns another subnet_wrapper rather than an input details object as it
// does in add_layer.
return dimpl::get_input_details(layer<net_type::num_layers-1>(net));
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -2935,6 +2948,7 @@ namespace dlib ...@@ -2935,6 +2948,7 @@ namespace dlib
public: public:
typedef SUBNET subnet_type; typedef SUBNET subnet_type;
typedef typename subnet_type::input_type input_type; typedef typename subnet_type::input_type input_type;
typedef typename subnet_type::input_layer_type input_layer_type;
typedef int layer_details_type; // not really used anywhere, but required by subnet_wrapper. typedef int layer_details_type; // not really used anywhere, but required by subnet_wrapper.
const static size_t num_layers = subnet_type::num_layers + 1; const static size_t num_layers = subnet_type::num_layers + 1;
const static size_t num_computational_layers = subnet_type::num_computational_layers; const static size_t num_computational_layers = subnet_type::num_computational_layers;
...@@ -3042,6 +3056,9 @@ namespace dlib ...@@ -3042,6 +3056,9 @@ namespace dlib
return subnetwork; return subnetwork;
} }
const input_layer_type& input_layer() const { return subnet().input_layer(); }
input_layer_type& input_layer() { return subnet().input_layer(); }
unsigned int sample_expansion_factor() const { return subnet().sample_expansion_factor(); } unsigned int sample_expansion_factor() const { return subnet().sample_expansion_factor(); }
void clean() void clean()
......
...@@ -308,6 +308,7 @@ namespace dlib ...@@ -308,6 +308,7 @@ namespace dlib
typedef LAYER_DETAILS layer_details_type; typedef LAYER_DETAILS layer_details_type;
typedef SUBNET subnet_type; typedef SUBNET subnet_type;
typedef typename subnet_type::input_type input_type; typedef typename subnet_type::input_type input_type;
typedef typename subnet_type::input_layer_type input_layer_type;
// num_computational_layers will always give the number of layers in the network // num_computational_layers will always give the number of layers in the network
// that transform tensors (i.e. layers defined by something that implements the // that transform tensors (i.e. layers defined by something that implements the
// EXAMPLE_COMPUTATIONAL_LAYER_ interface). This is all the layers except for // EXAMPLE_COMPUTATIONAL_LAYER_ interface). This is all the layers except for
...@@ -446,6 +447,26 @@ namespace dlib ...@@ -446,6 +447,26 @@ namespace dlib
- returns the immediate subnetwork of *this network. - returns the immediate subnetwork of *this network.
!*/ !*/
const input_layer_type& input_layer(
) const;
/*!
ensures
- returns the very first layer in *this network. It's equivalent to calling
subnet() recursively until you get to the first layer. This means it will return
the object that is an implementation of the EXAMPLE_INPUT_LAYER interface defined
in input_abstract.h
!*/
input_layer_type& input_layer(
);
/*!
ensures
- returns the very first layer in *this network. It's equivalent to calling
subnet() recursively until you get to the first layer. This means it will return
the object that is an implementation of the EXAMPLE_INPUT_LAYER interface defined
in input_abstract.h
!*/
const layer_details_type& layer_details( const layer_details_type& layer_details(
) const; ) const;
/*! /*!
...@@ -730,6 +751,7 @@ namespace dlib ...@@ -730,6 +751,7 @@ namespace dlib
typedef LOSS_DETAILS loss_details_type; typedef LOSS_DETAILS loss_details_type;
typedef SUBNET subnet_type; typedef SUBNET subnet_type;
typedef typename subnet_type::input_type input_type; typedef typename subnet_type::input_type input_type;
typedef typename subnet_type::input_layer_type input_layer_type;
const static size_t num_computational_layers = subnet_type::num_computational_layers; const static size_t num_computational_layers = subnet_type::num_computational_layers;
const static size_t num_layers = subnet_type::num_layers + 1; const static size_t num_layers = subnet_type::num_layers + 1;
// If LOSS_DETAILS is an unsupervised loss then training_label_type==no_label_type. // If LOSS_DETAILS is an unsupervised loss then training_label_type==no_label_type.
...@@ -818,6 +840,26 @@ namespace dlib ...@@ -818,6 +840,26 @@ namespace dlib
- returns the immediate subnetwork of *this network. - returns the immediate subnetwork of *this network.
!*/ !*/
const input_layer_type& input_layer(
) const;
/*!
ensures
- returns the very first layer in *this network. It's equivalent to calling
subnet() recursively until you get to the first layer. This means it will return
the object that is an implementation of the EXAMPLE_INPUT_LAYER interface defined
in input_abstract.h
!*/
input_layer_type& input_layer(
);
/*!
ensures
- returns the very first layer in *this network. It's equivalent to calling
subnet() recursively until you get to the first layer. This means it will return
the object that is an implementation of the EXAMPLE_INPUT_LAYER interface defined
in input_abstract.h
!*/
const loss_details_type& loss_details( const loss_details_type& loss_details(
) const; ) const;
/*! /*!
...@@ -1357,6 +1399,7 @@ namespace dlib ...@@ -1357,6 +1399,7 @@ namespace dlib
typedef SUBNET subnet_type; typedef SUBNET subnet_type;
typedef typename SUBNET::input_type input_type; typedef typename SUBNET::input_type input_type;
typedef typename subnet_type::input_layer_type input_layer_type;
const static size_t num_computational_layers = (REPEATED_LAYER<SUBNET>::num_computational_layers-SUBNET::num_computational_layers)*num + SUBNET::num_computational_layers; const static size_t num_computational_layers = (REPEATED_LAYER<SUBNET>::num_computational_layers-SUBNET::num_computational_layers)*num + SUBNET::num_computational_layers;
const static size_t num_layers = (REPEATED_LAYER<SUBNET>::num_layers-SUBNET::num_layers)*num + SUBNET::num_layers; const static size_t num_layers = (REPEATED_LAYER<SUBNET>::num_layers-SUBNET::num_layers)*num + SUBNET::num_layers;
typedef REPEATED_LAYER<an_unspecified_input_type> repeated_layer_type; typedef REPEATED_LAYER<an_unspecified_input_type> repeated_layer_type;
...@@ -1437,6 +1480,27 @@ namespace dlib ...@@ -1437,6 +1480,27 @@ namespace dlib
- returns the SUBNET base network that repeat sits on top of. If you want - returns the SUBNET base network that repeat sits on top of. If you want
to access the REPEATED_LAYER components then you must use get_repeated_layer(). to access the REPEATED_LAYER components then you must use get_repeated_layer().
!*/ !*/
const input_layer_type& input_layer(
) const;
/*!
ensures
- returns the very first layer in *this network. It's equivalent to calling
subnet() recursively until you get to the first layer. This means it will return
the object that is an implementation of the EXAMPLE_INPUT_LAYER interface defined
in input_abstract.h
!*/
input_layer_type& input_layer(
);
/*!
ensures
- returns the very first layer in *this network. It's equivalent to calling
subnet() recursively until you get to the first layer. This means it will return
the object that is an implementation of the EXAMPLE_INPUT_LAYER interface defined
in input_abstract.h
!*/
}; };
template < size_t num, template<typename> class T, typename U > template < size_t num, template<typename> class T, typename U >
...@@ -1651,13 +1715,11 @@ namespace dlib ...@@ -1651,13 +1715,11 @@ namespace dlib
); );
/*! /*!
requires requires
- net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, repeat, or
add_tag_layer. add_tag_layer.
ensures ensures
- returns the input later of the given network object. Specifically, this - returns the input later of the given network object. This is the same as just calling
function is equivalent to calling: net.input_layer().
layer<net_type::num_layers-1>(net);
That is, you get the input layer details object for the network.
!*/ !*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
...@@ -83,6 +83,26 @@ namespace dlib ...@@ -83,6 +83,26 @@ namespace dlib
begins with layer2. begins with layer2.
!*/ !*/
const INPUT_LAYER& input_layer(
) const;
/*!
ensures
- returns the very first layer in *this network. It's equivalent to calling
subnet() recursively until you get to the first layer. This means it will return
the object that is an implementation of the EXAMPLE_INPUT_LAYER interface defined
in input_abstract.h
!*/
INPUT_LAYER& input_layer(
);
/*!
ensures
- returns the very first layer in *this network. It's equivalent to calling
subnet() recursively until you get to the first layer. This means it will return
the object that is an implementation of the EXAMPLE_INPUT_LAYER interface defined
in input_abstract.h
!*/
const layer_details_type& layer_details( const layer_details_type& layer_details(
) const; ) const;
/*! /*!
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment