Commit 1db057d6 authored by Davis King's avatar Davis King
Browse files

Refactored the code in the http server so that it will be more reusable

by other tools.
parent 05b1ba8b
...@@ -25,103 +25,94 @@ ...@@ -25,103 +25,94 @@
namespace dlib namespace dlib
{ {
template < // ----------------------------------------------------------------------------------------
typename server_base
> class http_parse_error : public error
class server_http_1 : public server_base
{ {
public:
http_parse_error(const std::string& str, int http_error_code_):
error(str),http_error_code(http_error_code_) {}
/*! const int http_error_code;
CONVENTION };
this extension doesn't add any new state to this object.
!*/
// ----------------------------------------------------------------------------------------
template <typename Key, typename Value>
class constmap : public std::map<Key, Value>
{
public: public:
const Value& operator[](const Key& k) const
server_http_1()
{ {
max_content_length = 10*1024*1024; // 10MB static const Value dummy = Value();
}
unsigned long get_max_content_length ( typename std::map<Key, Value>::const_iterator ci = std::map<Key, Value>::find(k);
) const { return max_content_length; }
void set_max_content_length ( if ( ci == this->end() )
unsigned long max_length return dummy;
) else
{ return ci->second;
max_content_length = max_length;
} }
template <typename Key, typename Value> Value& operator[](const Key& k)
class constmap : public std::map<Key, Value>
{
public:
const Value& operator[](const Key& k) const
{
static const Value dummy = Value();
typename std::map<Key, Value>::const_iterator ci = std::map<Key, Value>::find(k);
if ( ci == this->end() )
return dummy;
else
return ci->second;
}
Value& operator[](const Key& k)
{
return std::map<Key, Value>::operator [](k);
}
};
typedef constmap< std::string, std::string > key_value_map;
struct incoming_things
{ {
incoming_things() : foreign_port(0), local_port(0) {} return std::map<Key, Value>::operator [](k);
}
std::string path; };
std::string request_type;
std::string content_type;
std::string protocol;
std::string body;
key_value_map queries;
key_value_map cookies;
key_value_map headers;
std::string foreign_ip; typedef constmap< std::string, std::string > key_value_map;
unsigned short foreign_port;
std::string local_ip;
unsigned short local_port;
};
struct outgoing_things struct incoming_things
{ {
outgoing_things() : http_return(200) { } incoming_things (
const std::string& foreign_ip_,
const std::string& local_ip_,
unsigned short foreign_port_,
unsigned short local_port_
):
foreign_ip(foreign_ip_),
foreign_port(foreign_port_),
local_ip(local_ip_),
local_port(local_port_)
{}
std::string path;
std::string request_type;
std::string content_type;
std::string protocol;
std::string body;
key_value_map queries;
key_value_map cookies;
key_value_map headers;
std::string foreign_ip;
unsigned short foreign_port;
std::string local_ip;
unsigned short local_port;
};
key_value_map cookies; struct outgoing_things
key_value_map headers; {
unsigned short http_return; outgoing_things() : http_return(200), http_return_status("OK") { }
std::string http_return_status;
};
key_value_map cookies;
key_value_map headers;
unsigned short http_return;
std::string http_return_status;
};
private: // ----------------------------------------------------------------------------------------
virtual const std::string on_request (
const incoming_things& incoming,
outgoing_things& outgoing
) = 0;
unsigned char to_hex( unsigned char x ) const namespace http_impl
{
inline unsigned char to_hex( unsigned char x )
{ {
return x + (x > 9 ? ('A'-10) : '0'); return x + (x > 9 ? ('A'-10) : '0');
} }
const std::string urlencode( const std::string& s ) const inline const std::string urlencode( const std::string& s )
{ {
std::ostringstream os; std::ostringstream os;
...@@ -146,9 +137,9 @@ namespace dlib ...@@ -146,9 +137,9 @@ namespace dlib
return os.str(); return os.str();
} }
unsigned char from_hex ( inline unsigned char from_hex (
unsigned char ch unsigned char ch
) const )
{ {
if (ch <= '9' && ch >= '0') if (ch <= '9' && ch >= '0')
ch -= '0'; ch -= '0';
...@@ -161,9 +152,9 @@ namespace dlib ...@@ -161,9 +152,9 @@ namespace dlib
return ch; return ch;
} }
const std::string urldecode ( inline const std::string urldecode (
const std::string& str const std::string& str
) const )
{ {
using namespace std; using namespace std;
string result; string result;
...@@ -190,7 +181,10 @@ namespace dlib ...@@ -190,7 +181,10 @@ namespace dlib
return result; return result;
} }
void parse_url(std::string word, key_value_map& queries) inline void parse_url(
std::string word,
key_value_map& queries
)
/*! /*!
Parses the query string of a URL. word should be the stuff that comes Parses the query string of a URL. word should be the stuff that comes
after the ? in the query URL. after the ? in the query URL.
...@@ -220,11 +214,11 @@ namespace dlib ...@@ -220,11 +214,11 @@ namespace dlib
} }
} }
void read_with_limit( inline void read_with_limit(
std::istream& in, std::istream& in,
std::string& buffer, std::string& buffer,
int delim = '\n' int delim = '\n'
) const )
{ {
using namespace std; using namespace std;
const size_t max = 16*1024; const size_t max = 16*1024;
...@@ -236,6 +230,10 @@ namespace dlib ...@@ -236,6 +230,10 @@ namespace dlib
buffer += (char)in.get(); buffer += (char)in.get();
} }
// if we quit the loop because the data is longer than expected or we hit EOF
if (in.peek() == EOF || buffer.size() == max)
throw http_parse_error("HTTP field from client is too long", 414);
// Make sure the last char is the delim. // Make sure the last char is the delim.
if (in.get() != delim) if (in.get() != delim)
{ {
...@@ -252,238 +250,314 @@ namespace dlib ...@@ -252,238 +250,314 @@ namespace dlib
} }
} }
} }
}
void on_connect (
std::istream& in,
std::ostream& out,
const std::string& foreign_ip,
const std::string& local_ip,
unsigned short foreign_port,
unsigned short local_port,
uint64
)
{
bool my_fault = true;
using namespace std;
try inline unsigned long parse_http_request (
std::istream& in,
incoming_things& incoming,
unsigned long max_content_length
)
{
using namespace std;
using namespace http_impl;
read_with_limit(in, incoming.request_type, ' ');
// get the path
read_with_limit(in, incoming.path, ' ');
// Get the HTTP/1.1 - Ignore for now...
read_with_limit(in, incoming.protocol);
key_value_map& incoming_headers = incoming.headers;
key_value_map& cookies = incoming.cookies;
std::string& path = incoming.path;
std::string& content_type = incoming.content_type;
unsigned long content_length = 0;
string line;
read_with_limit(in, line);
string first_part_of_header;
string::size_type position_of_double_point;
// now loop over all the incoming_headers
while (line.size() > 2)
{
position_of_double_point = line.find_first_of(':');
if ( position_of_double_point != string::npos )
{ {
incoming_things incoming; first_part_of_header = dlib::trim(line.substr(0, position_of_double_point));
outgoing_things outgoing;
if ( !incoming_headers[first_part_of_header].empty() )
incoming_headers[ first_part_of_header ] += " ";
incoming_headers[first_part_of_header] += dlib::trim(line.substr(position_of_double_point+1));
incoming.foreign_ip = foreign_ip; // look for Content-Type:
incoming.foreign_port = foreign_port; if (line.size() > 14 && strings_equal_ignore_case(line, "Content-Type:", 13))
incoming.local_ip = local_ip;
incoming.local_port = local_port;
read_with_limit(in, incoming.request_type, ' ');
// get the path
read_with_limit(in, incoming.path, ' ');
// Get the HTTP/1.1 - Ignore for now...
read_with_limit(in, incoming.protocol);
key_value_map& incoming_headers = incoming.headers;
key_value_map& cookies = incoming.cookies;
std::string& path = incoming.path;
std::string& content_type = incoming.content_type;
unsigned long content_length = 0;
string line;
read_with_limit(in, line);
string first_part_of_header;
string::size_type position_of_double_point;
// now loop over all the incoming_headers
while (line.size() > 2)
{ {
position_of_double_point = line.find_first_of(':'); content_type = line.substr(14);
if ( position_of_double_point != string::npos ) if (content_type[content_type.size()-1] == '\r')
content_type.erase(content_type.size()-1);
}
// look for Content-Length:
else if (line.size() > 16 && strings_equal_ignore_case(line, "Content-Length:", 15))
{
istringstream sin(line.substr(16));
sin >> content_length;
if (!sin)
{ {
first_part_of_header = dlib::trim(line.substr(0, position_of_double_point)); throw http_parse_error("Invalid Content-Length of '" + line.substr(16) + "'", 411);
}
if ( !incoming_headers[first_part_of_header].empty() ) if (content_length > max_content_length)
incoming_headers[ first_part_of_header ] += " "; {
incoming_headers[first_part_of_header] += dlib::trim(line.substr(position_of_double_point+1)); std::ostringstream sout;
sout << "Content-Length of post back is too large. It must be less than " << max_content_length;
throw http_parse_error(sout.str(), 413);
}
}
// look for any cookies
else if (line.size() > 6 && strings_equal_ignore_case(line, "Cookie:", 7))
{
string::size_type pos = 6;
string key, value;
bool seen_key_start = false;
bool seen_equal_sign = false;
while (pos + 1 < line.size())
{
++pos;
// ignore whitespace between cookies
if (!seen_key_start && line[pos] == ' ')
continue;
// look for Content-Type: seen_key_start = true;
if (line.size() > 14 && strings_equal_ignore_case(line, "Content-Type:", 13)) if (!seen_equal_sign)
{
content_type = line.substr(14);
if (content_type[content_type.size()-1] == '\r')
content_type.erase(content_type.size()-1);
}
// look for Content-Length:
else if (line.size() > 16 && strings_equal_ignore_case(line, "Content-Length:", 15))
{
istringstream sin(line.substr(16));
sin >> content_length;
if (!sin)
content_length = 0;
}
// look for any cookies
else if (line.size() > 6 && strings_equal_ignore_case(line, "Cookie:", 7))
{ {
string::size_type pos = 6; if (line[pos] == '=')
string key, value;
bool seen_key_start = false;
bool seen_equal_sign = false;
while (pos + 1 < line.size())
{ {
++pos; seen_equal_sign = true;
// ignore whitespace between cookies
if (!seen_key_start && line[pos] == ' ')
continue;
seen_key_start = true;
if (!seen_equal_sign)
{
if (line[pos] == '=')
{
seen_equal_sign = true;
}
else
{
key += line[pos];
}
}
else
{
if (line[pos] == ';')
{
cookies[urldecode(key)] = urldecode(value);
seen_equal_sign = false;
seen_key_start = false;
key.clear();
value.clear();
}
else
{
value += line[pos];
}
}
} }
if (key.size() > 0) else
{
key += line[pos];
}
}
else
{
if (line[pos] == ';')
{ {
cookies[urldecode(key)] = urldecode(value); cookies[urldecode(key)] = urldecode(value);
seen_equal_sign = false;
seen_key_start = false;
key.clear(); key.clear();
value.clear(); value.clear();
} }
else
{
value += line[pos];
}
} }
} // no ':' in it! }
read_with_limit(in, line); if (key.size() > 0)
} // while (line.size() > 2 ) {
cookies[urldecode(key)] = urldecode(value);
// If there is data being posted back to us then load it into the incoming.body key.clear();
// string. value.clear();
if (content_length > max_content_length) }
{
dlog << LERROR << "Request from: " << foreign_ip << " - body content length " << content_length << " exceeded max content length of " << max_content_length;
in.setstate(ios::badbit);
}
else if ( content_length > 0)
{
incoming.body.resize(content_length);
in.read(&incoming.body[0],content_length);
} }
} // no ':' in it!
read_with_limit(in, line);
} // while (line.size() > 2 )
// If there is data being posted back to us as a query string then
// pick out the queries using parse_url.
if ((strings_equal_ignore_case(incoming.request_type, "POST") ||
strings_equal_ignore_case(incoming.request_type, "PUT")) &&
strings_equal_ignore_case(left_substr(content_type,";"), "application/x-www-form-urlencoded"))
{
parse_url(incoming.body, incoming.queries);
}
string::size_type pos = path.find_first_of("?"); // If there is data being posted back to us as a query string then
if (pos != string::npos) // pick out the queries using parse_url.
{ if ((strings_equal_ignore_case(incoming.request_type, "POST") ||
parse_url(path.substr(pos+1), incoming.queries); strings_equal_ignore_case(incoming.request_type, "PUT")) &&
} strings_equal_ignore_case(left_substr(content_type,";"), "application/x-www-form-urlencoded"))
{
if (content_length > 0)
{
incoming.body.resize(content_length);
in.read(&incoming.body[0],content_length);
}
parse_url(incoming.body, incoming.queries);
}
string::size_type pos = path.find_first_of("?");
if (pos != string::npos)
{
parse_url(path.substr(pos+1), incoming.queries);
}
my_fault = false;
key_value_map& new_cookies = outgoing.cookies;
key_value_map& response_headers = outgoing.headers;
// Set some defaults if (!in)
outgoing.http_return = 200; throw http_parse_error("Error parsing HTTP request", 500);
outgoing.http_return_status = "OK";
// if there wasn't a problem with the input stream at some point return content_length;
// then lets trigger this request callback. }
std::string result;
if (in)
{
result = on_request(incoming, outgoing);
}
else
{
dlog << LERROR << "Request from: " << foreign_ip << " - Invalid request - Request Entity Too Large";
outgoing.http_return = 413;
outgoing.http_return_status = "Request Entity Too Large";
}
my_fault = true;
// only send this header if the user hasn't told us to send another kind inline void read_body (
bool has_content_type(false), std::istream& in,
has_location(false); incoming_things& incoming
for( typename key_value_map::const_iterator ci = response_headers.begin(); ci != response_headers.end(); ++ci ) )
{ {
if ( !has_content_type && strings_equal_ignore_case(ci->first , "content-type") ) // if the body hasn't already been loaded and there is data to load
{ if (incoming.body.size() == 0 &&
has_content_type = true; incoming.headers.count("Content-Length") != 0)
} {
else if ( !has_location && strings_equal_ignore_case(ci->first , "location") ) const unsigned long content_length = string_cast<unsigned long>(incoming.headers["Content-Length"]);
{
has_location = true;
}
}
if ( has_location ) incoming.body.resize(content_length);
{ if (content_length > 0)
outgoing.http_return = 302; {
} in.read(&incoming.body[0],content_length);
}
}
}
if ( !has_content_type ) inline void write_http_response (
{ std::ostream& out,
response_headers["Content-Type"] = "text/html"; outgoing_things outgoing,
} const std::string& result
)
{
using namespace http_impl;
key_value_map& new_cookies = outgoing.cookies;
key_value_map& response_headers = outgoing.headers;
{ // only send this header if the user hasn't told us to send another kind
ostringstream os; bool has_content_type = false, has_location = false;
os << result.size(); for( typename key_value_map::const_iterator ci = response_headers.begin(); ci != response_headers.end(); ++ci )
{
if ( !has_content_type && strings_equal_ignore_case(ci->first , "content-type") )
{
has_content_type = true;
}
else if ( !has_location && strings_equal_ignore_case(ci->first , "location") )
{
has_location = true;
}
}
response_headers["Content-Length"] = os.str(); if ( has_location )
} {
outgoing.http_return = 302;
}
out << "HTTP/1.0 " << outgoing.http_return << " " << outgoing.http_return_status << "\r\n"; if ( !has_content_type )
{
response_headers["Content-Type"] = "text/html";
}
// Set any new headers response_headers["Content-Length"] = cast_to_string(result.size());
for( typename key_value_map::const_iterator ci = response_headers.begin(); ci != response_headers.end(); ++ci )
{
out << ci->first << ": " << ci->second << "\r\n";
}
// set any cookies out << "HTTP/1.0 " << outgoing.http_return << " " << outgoing.http_return_status << "\r\n";
for( typename key_value_map::const_iterator ci = new_cookies.begin(); ci != new_cookies.end(); ++ci )
{ // Set any new headers
out << "Set-Cookie: " << urlencode(ci->first) << '=' << urlencode(ci->second) << "\r\n"; for( typename key_value_map::const_iterator ci = response_headers.begin(); ci != response_headers.end(); ++ci )
} {
out << "\r\n" << result; out << ci->first << ": " << ci->second << "\r\n";
}
// set any cookies
for( typename key_value_map::const_iterator ci = new_cookies.begin(); ci != new_cookies.end(); ++ci )
{
out << "Set-Cookie: " << urlencode(ci->first) << '=' << urlencode(ci->second) << "\r\n";
}
out << "\r\n" << result;
}
inline void write_http_response (
std::ostream& out,
const http_parse_error& e
)
{
outgoing_things outgoing;
outgoing.http_return = e.http_error_code;
outgoing.http_return_status = e.what();
write_http_response(out, outgoing, std::string("Error processing request: ") + e.what());
}
inline void write_http_response (
std::ostream& out,
const std::exception& e
)
{
outgoing_things outgoing;
outgoing.http_return = 500;
outgoing.http_return_status = e.what();
write_http_response(out, outgoing, std::string("Error processing request: ") + e.what());
}
// ----------------------------------------------------------------------------------------
template <
typename server_base
>
class server_http_1 : public server_base
{
/*!
CONVENTION
this extension doesn't add any new state to this object.
!*/
public:
server_http_1()
{
max_content_length = 10*1024*1024; // 10MB
}
unsigned long get_max_content_length (
) const { return max_content_length; }
void set_max_content_length (
unsigned long max_length
)
{
max_content_length = max_length;
}
private:
virtual const std::string on_request (
const incoming_things& incoming,
outgoing_things& outgoing
) = 0;
void on_connect (
std::istream& in,
std::ostream& out,
const std::string& foreign_ip,
const std::string& local_ip,
unsigned short foreign_port,
unsigned short local_port,
uint64
)
{
try
{
incoming_things incoming(foreign_ip, local_ip, foreign_port, local_port);
outgoing_things outgoing;
parse_http_request(in, incoming, max_content_length);
read_body(in, incoming);
const std::string& result = on_request(incoming, outgoing);
write_http_response(out, outgoing, result);
} }
catch (std::bad_alloc&) catch (http_parse_error& e)
{ {
dlog << LERROR << "We ran out of memory in server_http::on_connect()"; dlog << LERROR << "Error processing request from: " << foreign_ip << " - " << e.what();
// If this is an escaped exception from on_request then let it fly! write_http_response(out, e);
// Seriously though, this way it is obvious to the user that something bad happened }
// since they probably won't have the dlib logger enabled. catch (std::exception& e)
if (!my_fault) {
throw; dlog << LERROR << "Error processing request from: " << foreign_ip << " - " << e.what();
write_http_response(out, e);
} }
} }
unsigned long max_content_length; unsigned long max_content_length;
...@@ -498,6 +572,3 @@ namespace dlib ...@@ -498,6 +572,3 @@ namespace dlib
#endif // DLIB_SERVER_HTTp_1_ #endif // DLIB_SERVER_HTTp_1_
...@@ -193,7 +193,8 @@ namespace dlib ...@@ -193,7 +193,8 @@ namespace dlib
- outgoing.http_return and outgoing.http_return_status may be set to override the - outgoing.http_return and outgoing.http_return_status may be set to override the
default HTTP return code of 200 OK default HTTP return code of 200 OK
throws throws
- does not throw any exceptions - throws only exceptions derived from std::exception. If an exception is thrown
then the error string from the exception is returned to the web browser.
!*/ !*/
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment