Commit 1de47514 authored by Davis King

Make input_layer() work with networks that contain repeat layers.

Do this by making all layers have a .input_layer() method, which in that
context is simple to implement: each layer just forwards the call to its
subnetwork until the input layer itself is reached.
parent ded68b9a
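
For context, here is a minimal sketch of the call this commit enables. The
network definition is illustrative only (it is not part of the commit);
before this change, input_layer(net) would not compile for networks
containing repeat layers:

    #include <dlib/dnn.h>
    using namespace dlib;

    // A small block to repeat, and a network stacking 4 copies of it.
    template <typename SUBNET> using block = relu<con<8,3,3,1,1,SUBNET>>;
    using net_type = loss_multiclass_log<fc<10,
                         repeat<4, block,
                         input<matrix<unsigned char>>>>>;

    int main()
    {
        net_type net;
        // With this commit, every layer (including repeat) forwards
        // .input_layer() down the stack, so this call now compiles:
        auto& inp = input_layer(net);
        (void)inp;
    }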
@@ -652,7 +652,10 @@ namespace dlib
// Not much here because in this case T is one of the input layer types
// that doesn't have anything in it.
typedef T layer_details_type;
typedef T input_layer_type;
const layer_details_type& layer_details() const { return l; }
const input_layer_type& input_layer() const { return l; }
input_layer_type& input_layer() { return l; }
unsigned int sample_expansion_factor() const { return _sample_expansion_factor; }
private:
T& l;
@@ -671,6 +674,7 @@ namespace dlib
const static size_t num_computational_layers = T::num_computational_layers;
const static size_t num_layers = T::num_layers;
typedef typename T::layer_details_type layer_details_type;
typedef typename T::input_layer_type input_layer_type;
subnet_wrapper(T& l_, unsigned int = 0) : l(l_),subnetwork(l.subnet(), l.sample_expansion_factor()) {}
@@ -683,6 +687,9 @@ namespace dlib
subnet_wrapper<typename T::subnet_type,false>& subnet() { return subnetwork; }
unsigned int sample_expansion_factor() const { return l.sample_expansion_factor(); }
const input_layer_type& input_layer() const { return l.input_layer(); }
input_layer_type& input_layer() { return l.input_layer(); }
private:
T& l;
subnet_wrapper<typename T::subnet_type,false> subnetwork;
@@ -700,6 +707,7 @@ namespace dlib
const static size_t num_computational_layers = T::num_computational_layers;
const static size_t num_layers = T::num_layers;
typedef typename T::layer_details_type layer_details_type;
typedef typename T::input_layer_type input_layer_type;
subnet_wrapper(T& l_, unsigned int = 0) : l(l_),subnetwork(l.subnet(), l.sample_expansion_factor()) {}
@@ -712,6 +720,9 @@ namespace dlib
subnet_wrapper<typename T::subnet_type,false>& subnet() { return subnetwork; }
unsigned int sample_expansion_factor() const { return l.sample_expansion_factor(); }
const input_layer_type& input_layer() const { return l.input_layer(); }
input_layer_type& input_layer() { return l.input_layer(); }
private:
T& l;
subnet_wrapper<typename T::subnet_type,false> subnetwork;
@@ -738,6 +749,7 @@ namespace dlib
public:
typedef LAYER_DETAILS layer_details_type;
typedef SUBNET subnet_type;
typedef typename subnet_type::input_layer_type input_layer_type;
typedef typename subnet_type::input_type input_type;
const static size_t num_layers = subnet_type::num_layers + 1;
const static size_t num_computational_layers = subnet_type::num_computational_layers + 1;
@@ -1037,6 +1049,9 @@ namespace dlib
const subnet_type& subnet() const { return *subnetwork; }
subnet_type& subnet() { return *subnetwork; }
const input_layer_type& input_layer() const { return subnet().input_layer(); }
input_layer_type& input_layer() { return subnet().input_layer(); }
const layer_details_type& layer_details() const { return details; }
layer_details_type& layer_details() { return details; }
@@ -1164,6 +1179,7 @@ namespace dlib
public:
typedef LAYER_DETAILS layer_details_type;
typedef INPUT_LAYER subnet_type;
typedef INPUT_LAYER input_layer_type;
typedef typename INPUT_LAYER::input_type input_type;
const static size_t num_layers = 2;
const static size_t num_computational_layers = 1;
@@ -1198,7 +1214,7 @@ namespace dlib
add_layer(
const add_layer<T,U,E>& item
):
input_layer(item.subnet()),
input_layer_(item.subnet()),
details(item.layer_details()),
this_layer_setup_called(item.this_layer_setup_called),
gradient_input_is_stale(item.gradient_input_is_stale),
@@ -1223,7 +1239,7 @@ namespace dlib
add_layer(
const INPUT_LAYER& il
) :
input_layer(il),
input_layer_(il),
this_layer_setup_called(false),
gradient_input_is_stale(true),
get_output_and_gradient_input_disabled(false),
@@ -1245,7 +1261,7 @@ namespace dlib
INPUT_LAYER il
) :
details(std::move(layer_det)),
input_layer(std::move(il)),
input_layer_(std::move(il)),
this_layer_setup_called(false),
gradient_input_is_stale(true),
get_output_and_gradient_input_disabled(false),
@@ -1284,7 +1300,7 @@ namespace dlib
resizable_tensor& data
) const
{
input_layer.to_tensor(ibegin, iend, data);
input_layer_.to_tensor(ibegin, iend, data);
// make sure the input layer's to_tensor() function is implemented properly.
DLIB_CASSERT(data.num_samples() >= std::distance(ibegin,iend),
"The input layer can't produce fewer output tensors than there are inputs.");
@@ -1403,8 +1419,11 @@ namespace dlib
tensor& get_parameter_gradient (
) { return params_grad; }
const subnet_type& subnet() const { return input_layer; }
subnet_type& subnet() { return input_layer; }
const subnet_type& subnet() const { return input_layer_; }
subnet_type& subnet() { return input_layer_; }
const subnet_type& input_layer() const { return input_layer_; }
subnet_type& input_layer() { return input_layer_; }
const layer_details_type& layer_details() const { return details; }
layer_details_type& layer_details() { return details; }
@@ -1426,7 +1445,7 @@ namespace dlib
{
int version = 3;
serialize(version, out);
serialize(item.input_layer, out);
serialize(item.input_layer_, out);
serialize(item.details, out);
serialize(item.this_layer_setup_called, out);
serialize(item.gradient_input_is_stale, out);
@@ -1443,7 +1462,7 @@ namespace dlib
deserialize(version, in);
if (!(2 <= version && version <= 3))
throw serialization_error("Unexpected version found while deserializing dlib::add_layer.");
deserialize(item.input_layer, in);
deserialize(item.input_layer_, in);
deserialize(item.details, in);
deserialize(item.this_layer_setup_called, in);
deserialize(item.gradient_input_is_stale, in);
@@ -1512,7 +1531,7 @@ namespace dlib
void swap(add_layer& item)
{
std::swap(input_layer, item.input_layer);
std::swap(input_layer_, item.input_layer_);
std::swap(details, item.details);
std::swap(this_layer_setup_called, item.this_layer_setup_called);
std::swap(gradient_input_is_stale, item.gradient_input_is_stale);
@@ -1523,7 +1542,7 @@ namespace dlib
std::swap(_sample_expansion_factor, item._sample_expansion_factor);
}
subnet_type input_layer;
subnet_type input_layer_;
LAYER_DETAILS details;
bool this_layer_setup_called;
bool gradient_input_is_stale;
@@ -1558,6 +1577,7 @@ namespace dlib
public:
typedef SUBNET subnet_type;
typedef typename subnet_type::input_type input_type;
typedef typename subnet_type::input_layer_type input_layer_type;
typedef int layer_details_type; // not really used anywhere, but required by subnet_wrapper.
const static size_t num_layers = subnet_type::num_layers + 1;
const static size_t num_computational_layers = subnet_type::num_computational_layers;
@@ -1652,6 +1672,9 @@ namespace dlib
const subnet_type& subnet() const { return subnetwork; }
subnet_type& subnet() { return subnetwork; }
const input_layer_type& input_layer() const { return subnet().input_layer(); }
input_layer_type& input_layer() { return subnet().input_layer(); }
unsigned int sample_expansion_factor() const { return subnet().sample_expansion_factor(); }
void clean()
@@ -1760,6 +1783,7 @@ namespace dlib
public:
typedef SUBNET subnet_type;
typedef typename SUBNET::input_type input_type;
typedef typename subnet_type::input_layer_type input_layer_type;
typedef int layer_details_type; // not really used anywhere, but required by subnet_wrapper.
const static size_t comp_layers_in_each_group = (REPEATED_LAYER<SUBNET>::num_computational_layers-SUBNET::num_computational_layers);
const static size_t comp_layers_in_repeated_group = comp_layers_in_each_group*num;
@@ -1951,6 +1975,9 @@ namespace dlib
const subnet_type& subnet() const { return subnetwork; }
subnet_type& subnet() { return subnetwork; }
const input_layer_type& input_layer() const { return subnet().input_layer(); }
input_layer_type& input_layer() { return subnet().input_layer(); }
unsigned int sample_expansion_factor() const { return subnet().sample_expansion_factor(); }
void clean()
@@ -2047,6 +2074,7 @@ namespace dlib
public:
typedef INPUT_LAYER subnet_type;
typedef typename subnet_type::input_type input_type;
typedef INPUT_LAYER input_layer_type;
typedef int layer_details_type; // not really used anywhere, but required by subnet_wrapper.
const static size_t num_computational_layers = 0;
const static size_t num_layers = 2;
@@ -2062,7 +2090,7 @@ namespace dlib
template <typename T, typename E>
add_tag_layer(
const add_tag_layer<ID,T,E>& item
) : input_layer(item.subnet()),
) : input_layer_(item.subnet()),
cached_output(item.cached_output),
cached_output_ptr(nullptr),
grad_final(item.grad_final),
@@ -2074,7 +2102,7 @@ namespace dlib
add_tag_layer(
T ...args
) :
input_layer(std::move(args)...),
input_layer_(std::move(args)...),
cached_output_ptr(nullptr),
gradient_input_is_stale(true),
_sample_expansion_factor(0)
@@ -2096,7 +2124,7 @@ namespace dlib
resizable_tensor& data
) const
{
input_layer.to_tensor(ibegin,iend,data);
input_layer_.to_tensor(ibegin,iend,data);
// make sure the input layer's to_tensor() function is implemented properly.
DLIB_CASSERT(data.num_samples() >= std::distance(ibegin,iend),
@@ -2116,7 +2144,7 @@ namespace dlib
forward_iterator iend
)
{
input_layer.to_tensor(ibegin,iend,cached_output);
input_layer_.to_tensor(ibegin,iend,cached_output);
cached_output_ptr = nullptr;
return get_output();
}
@@ -2184,8 +2212,11 @@ namespace dlib
update_parameters(make_sstack(solvers), learning_rate);
}
const subnet_type& subnet() const { return input_layer; }
subnet_type& subnet() { return input_layer; }
const subnet_type& subnet() const { return input_layer_; }
subnet_type& subnet() { return input_layer_; }
const input_layer_type& input_layer() const { return input_layer_; }
input_layer_type& input_layer() { return input_layer_; }
void clean()
{
@@ -2198,7 +2229,7 @@ namespace dlib
{
int version = 2;
serialize(version, out);
serialize(item.input_layer, out);
serialize(item.input_layer_, out);
serialize(item.cached_output, out);
serialize(item.grad_final, out);
serialize(item.gradient_input_is_stale, out);
@@ -2211,7 +2242,7 @@ namespace dlib
deserialize(version, in);
if (!(1 <= version && version <= 2))
throw serialization_error("Unexpected version found while deserializing dlib::add_tag_layer.");
deserialize(item.input_layer, in);
deserialize(item.input_layer_, in);
deserialize(item.cached_output, in);
deserialize(item.grad_final, in);
deserialize(item.gradient_input_is_stale, in);
@@ -2275,7 +2306,7 @@ namespace dlib
void swap(add_tag_layer& item)
{
std::swap(input_layer, item.input_layer);
std::swap(input_layer_, item.input_layer_);
std::swap(cached_output, item.cached_output);
std::swap(cached_output_ptr, item.cached_output_ptr);
std::swap(grad_final, item.grad_final);
@@ -2283,7 +2314,7 @@ namespace dlib
std::swap(_sample_expansion_factor, item._sample_expansion_factor);
}
subnet_type input_layer;
subnet_type input_layer_;
resizable_tensor cached_output;
tensor* cached_output_ptr;
resizable_tensor grad_final;
@@ -2348,6 +2379,7 @@ namespace dlib
typedef LOSS_DETAILS loss_details_type;
typedef SUBNET subnet_type;
typedef typename subnet_type::input_type input_type;
typedef typename subnet_type::input_layer_type input_layer_type;
const static size_t num_layers = subnet_type::num_layers + 1;
// Note that the loss layer doesn't count as an additional computational layer.
const static size_t num_computational_layers = subnet_type::num_computational_layers;
@@ -2628,6 +2660,10 @@ namespace dlib
const subnet_type& subnet() const { return subnetwork; }
subnet_type& subnet() { return subnetwork; }
const input_layer_type& input_layer() const { return subnet().input_layer(); }
input_layer_type& input_layer() { return subnet().input_layer(); }
const loss_details_type& loss_details() const { return loss; }
loss_details_type& loss_details() { return loss; }
@@ -2888,43 +2924,20 @@ namespace dlib
// ----------------------------------------------------------------------------------------
namespace dimpl
{
    template <typename T>
    T& get_input_details (
        T& net
    )
    {
        return net;
    }

    template <typename T, bool is_first, typename enabled>
    auto get_input_details (
        dimpl::subnet_wrapper<T,is_first,enabled>& net
    ) -> decltype(net.layer_details())&
    {
        return net.layer_details();
    }

    template <typename T, bool is_first, typename enabled>
    auto get_input_details (
        const dimpl::subnet_wrapper<T,is_first,enabled>& net
    ) -> decltype(net.layer_details())&
    {
        return net.layer_details();
    }
}

template <typename net_type>
auto input_layer (
    net_type& net
) -> decltype(dimpl::get_input_details(layer<net_type::num_layers-1>(net)))&
{
    // Calling input_layer() on a subnet_wrapper is a little funny since the behavior of
    // .subnet() returns another subnet_wrapper rather than an input details object as it
    // does in add_layer.
    return dimpl::get_input_details(layer<net_type::num_layers-1>(net));
}
template <typename net_type>
typename net_type::input_layer_type& input_layer (
    net_type& net
)
{
    return net.input_layer();
}

template <typename net_type>
const typename net_type::input_layer_type& input_layer (
    const net_type& net
)
{
    return net.input_layer();
}
// ----------------------------------------------------------------------------------------
@@ -2935,6 +2948,7 @@ namespace dlib
public:
typedef SUBNET subnet_type;
typedef typename subnet_type::input_type input_type;
typedef typename subnet_type::input_layer_type input_layer_type;
typedef int layer_details_type; // not really used anywhere, but required by subnet_wrapper.
const static size_t num_layers = subnet_type::num_layers + 1;
const static size_t num_computational_layers = subnet_type::num_computational_layers;
......@@ -3042,6 +3056,9 @@ namespace dlib
return subnetwork;
}
const input_layer_type& input_layer() const { return subnet().input_layer(); }
input_layer_type& input_layer() { return subnet().input_layer(); }
unsigned int sample_expansion_factor() const { return subnet().sample_expansion_factor(); }
void clean()
......
@@ -308,6 +308,7 @@ namespace dlib
typedef LAYER_DETAILS layer_details_type;
typedef SUBNET subnet_type;
typedef typename subnet_type::input_type input_type;
typedef typename subnet_type::input_layer_type input_layer_type;
// num_computational_layers will always give the number of layers in the network
// that transform tensors (i.e. layers defined by something that implements the
// EXAMPLE_COMPUTATIONAL_LAYER_ interface). This is all the layers except for
@@ -446,6 +447,26 @@ namespace dlib
- returns the immediate subnetwork of *this network.
!*/
const input_layer_type& input_layer(
) const;
/*!
ensures
- returns the very first layer in *this network. It's equivalent to calling
subnet() recursively until you get to the first layer. This means it will return
the object that is an implementation of the EXAMPLE_INPUT_LAYER interface defined
in input_abstract.h
!*/
input_layer_type& input_layer(
);
/*!
ensures
- returns the very first layer in *this network. It's equivalent to calling
subnet() recursively until you get to the first layer. This means it will return
the object that is an implementation of the EXAMPLE_INPUT_LAYER interface defined
in input_abstract.h
!*/
const layer_details_type& layer_details(
) const;
/*!
@@ -730,6 +751,7 @@ namespace dlib
typedef LOSS_DETAILS loss_details_type;
typedef SUBNET subnet_type;
typedef typename subnet_type::input_type input_type;
typedef typename subnet_type::input_layer_type input_layer_type;
const static size_t num_computational_layers = subnet_type::num_computational_layers;
const static size_t num_layers = subnet_type::num_layers + 1;
// If LOSS_DETAILS is an unsupervised loss then training_label_type==no_label_type.
@@ -818,6 +840,26 @@ namespace dlib
- returns the immediate subnetwork of *this network.
!*/
const input_layer_type& input_layer(
) const;
/*!
ensures
- returns the very first layer in *this network. It's equivalent to calling
subnet() recursively until you get to the first layer. This means it will return
the object that is an implementation of the EXAMPLE_INPUT_LAYER interface defined
in input_abstract.h
!*/
input_layer_type& input_layer(
);
/*!
ensures
- returns the very first layer in *this network. It's equivalent to calling
subnet() recursively until you get to the first layer. This means it will return
the object that is an implementation of the EXAMPLE_INPUT_LAYER interface defined
in input_abstract.h
!*/
const loss_details_type& loss_details(
) const;
/*!
@@ -1357,6 +1399,7 @@ namespace dlib
typedef SUBNET subnet_type;
typedef typename SUBNET::input_type input_type;
typedef typename subnet_type::input_layer_type input_layer_type;
const static size_t num_computational_layers = (REPEATED_LAYER<SUBNET>::num_computational_layers-SUBNET::num_computational_layers)*num + SUBNET::num_computational_layers;
const static size_t num_layers = (REPEATED_LAYER<SUBNET>::num_layers-SUBNET::num_layers)*num + SUBNET::num_layers;
typedef REPEATED_LAYER<an_unspecified_input_type> repeated_layer_type;
@@ -1437,6 +1480,27 @@ namespace dlib
- returns the SUBNET base network that repeat sits on top of. If you want
to access the REPEATED_LAYER components then you must use get_repeated_layer().
!*/
const input_layer_type& input_layer(
) const;
/*!
ensures
- returns the very first layer in *this network. It's equivalent to calling
subnet() recursively until you get to the first layer. This means it will return
the object that is an implementation of the EXAMPLE_INPUT_LAYER interface defined
in input_abstract.h
!*/
input_layer_type& input_layer(
);
/*!
ensures
- returns the very first layer in *this network. It's equivalent to calling
subnet() recursively until you get to the first layer. This means it will return
the object that is an implementation of the EXAMPLE_INPUT_LAYER interface defined
in input_abstract.h
!*/
};
template < size_t num, template<typename> class T, typename U >
@@ -1651,13 +1715,11 @@ namespace dlib
);
/*!
requires
- net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or
- net_type is an object of type add_layer, add_loss_layer, add_skip_layer, repeat, or
add_tag_layer.
ensures
- returns the input layer of the given network object. Specifically, this
function is equivalent to calling:
layer<net_type::num_layers-1>(net);
That is, you get the input layer details object for the network.
- returns the input layer of the given network object. This is the same as just calling
net.input_layer().
!*/
// ----------------------------------------------------------------------------------------
......
@@ -83,6 +83,26 @@ namespace dlib
begins with layer2.
!*/
const INPUT_LAYER& input_layer(
) const;
/*!
ensures
- returns the very first layer in *this network. It's equivalent to calling
subnet() recursively until you get to the first layer. This means it will return
the object that is an implementation of the EXAMPLE_INPUT_LAYER interface defined
in input_abstract.h
!*/
INPUT_LAYER& input_layer(
);
/*!
ensures
- returns the very first layer in *this network. It's equivalent to calling
subnet() recursively until you get to the first layer. This means it will return
the object that is an implementation of the EXAMPLE_INPUT_LAYER interface defined
in input_abstract.h
!*/
const layer_details_type& layer_details(
) const;
/*!
......
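
As a standalone illustration of the delegation pattern the diff implements
(plain C++, not dlib code): each wrapper layer re-exports its subnetwork's
input_layer_type and forwards input_layer() one level down, while the
bottom-most layer, which physically stores the input object, returns it
directly.

    #include <iostream>

    struct an_input { const char* name = "an_input"; };

    // Bottom of the stack: holds the input layer object directly,
    // so input_layer() just returns the stored member.
    struct bottom
    {
        typedef an_input input_layer_type;
        input_layer_type& input_layer() { return inp; }
        an_input inp;
    };

    // Every layer above: re-export the subnetwork's input_layer_type
    // and forward the call one level down.
    template <typename SUBNET>
    struct wrapper
    {
        typedef typename SUBNET::input_layer_type input_layer_type;
        SUBNET& subnet() { return subnetwork; }
        input_layer_type& input_layer() { return subnet().input_layer(); }
        SUBNET subnetwork;
    };

    int main()
    {
        wrapper<wrapper<bottom>> net;
        std::cout << net.input_layer().name << "\n"; // prints "an_input"
    }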