Commit 888e25fb authored by Artur Wojcik's avatar Artur Wojcik
Browse files

fix process class on Windows and HIPRTC driver

parent 0381151b
...@@ -88,21 +88,38 @@ int exec(const std::string& cmd, std::function<void(process::writer)> std_in) ...@@ -88,21 +88,38 @@ int exec(const std::string& cmd, std::function<void(process::writer)> std_in)
constexpr std::size_t MIGRAPHX_PROCESS_BUFSIZE = 4096; constexpr std::size_t MIGRAPHX_PROCESS_BUFSIZE = 4096;
enum class direction
{
input,
output
};
template <direction dir>
class pipe class pipe
{ {
public: public:
explicit pipe(bool inherit_handle = true) explicit pipe()
{ {
SECURITY_ATTRIBUTES attrs; SECURITY_ATTRIBUTES attrs;
attrs.nLength = sizeof(SECURITY_ATTRIBUTES); attrs.nLength = sizeof(SECURITY_ATTRIBUTES);
attrs.bInheritHandle = inherit_handle ? TRUE : FALSE; attrs.bInheritHandle = TRUE;
attrs.lpSecurityDescriptor = nullptr; attrs.lpSecurityDescriptor = nullptr;
if(CreatePipe(&m_read, &m_write, &attrs, 0) == FALSE) if(CreatePipe(&m_read, &m_write, &attrs, 0) == FALSE)
throw GetLastError(); throw GetLastError();
if(SetHandleInformation(&m_read, HANDLE_FLAG_INHERIT, 0) == FALSE) if(dir == direction::output)
throw GetLastError(); {
// Do not inherit the read handle for the output pipe
if(SetHandleInformation(m_read, HANDLE_FLAG_INHERIT, 0) == 0)
throw GetLastError();
}
else
{
// Do not inherit the write handle for the input pipe
if(SetHandleInformation(m_write, HANDLE_FLAG_INHERIT, 0) == 0)
throw GetLastError();
}
} }
pipe(const pipe&) = delete; pipe(const pipe&) = delete;
...@@ -112,10 +129,36 @@ class pipe ...@@ -112,10 +129,36 @@ class pipe
~pipe() ~pipe()
{ {
CloseHandle(m_read); if(m_write != nullptr)
m_read = nullptr; {
CloseHandle(m_write); CloseHandle(m_write);
m_write = nullptr; }
if(m_read != nullptr)
{
CloseHandle(m_read);
}
}
bool close_write_handle()
{
auto result = true;
if(m_write != nullptr)
{
result = CloseHandle(m_write) == TRUE;
m_write = nullptr;
}
return result;
}
bool close_read_handle()
{
auto result = true;
if(m_read != nullptr)
{
result = CloseHandle(m_read) == TRUE;
m_read = nullptr;
}
return result;
} }
std::optional<std::pair<bool, DWORD>> read(LPVOID buffer, DWORD length) const std::optional<std::pair<bool, DWORD>> read(LPVOID buffer, DWORD length) const
...@@ -148,42 +191,64 @@ class pipe ...@@ -148,42 +191,64 @@ class pipe
}; };
template <typename F> template <typename F>
int exec(const std::string& cmd, F f) int exec(const std::pair<std::string, std::string>& command, F f)
{ {
auto& [cwd, cmd] = command;
if((cmd.length() + 1) > MAX_PATH)
MIGRAPHX_THROW("Command too long, required maximum " + std::to_string(MAX_PATH) +
" characters (including terminating null character)");
try try
{ {
if(enabled(MIGRAPHX_TRACE_CMD_EXECUTE{})) if(enabled(MIGRAPHX_TRACE_CMD_EXECUTE{}))
std::cout << cmd << std::endl; std::cout << "[cwd=" << cwd << "]; cmd='" << cmd << "'\n";
STARTUPINFO info; STARTUPINFO info;
PROCESS_INFORMATION process_info; PROCESS_INFORMATION process_info;
pipe in{}, out{}; pipe<direction::input> input{};
pipe<direction::output> output{};
ZeroMemory(&info, sizeof(STARTUPINFO)); ZeroMemory(&info, sizeof(STARTUPINFO));
info.cb = sizeof(STARTUPINFO); info.cb = sizeof(STARTUPINFO);
info.hStdError = out.get_write_handle(); info.hStdError = output.get_write_handle();
info.hStdOutput = out.get_write_handle(); info.hStdOutput = output.get_write_handle();
info.hStdInput = in.get_read_handle(); info.hStdInput = input.get_read_handle();
info.dwFlags |= STARTF_USESTDHANDLES; info.dwFlags |= STARTF_USESTDHANDLES;
TCHAR cmdline[MAX_PATH];
std::copy(std::begin(cmd), std::end(cmd), std::begin(cmdline));
ZeroMemory(&process_info, sizeof(process_info)); ZeroMemory(&process_info, sizeof(process_info));
if(CreateProcess(nullptr, if(CreateProcess(nullptr,
const_cast<LPSTR>(cmd.c_str()), cmdline,
nullptr, nullptr,
nullptr, nullptr,
TRUE, TRUE,
0, 0,
nullptr, nullptr,
nullptr, cwd.empty() ? nullptr : static_cast<LPCSTR>(cwd.c_str()),
&info, &info,
&process_info) == FALSE) &process_info) == FALSE)
{ {
return GetLastError(); MIGRAPHX_THROW("Error creating process (" + std::to_string(GetLastError()) + ")");
} }
f(in, out); if(not output.close_write_handle())
MIGRAPHX_THROW("Error closing STDOUT handle for writing (" +
std::to_string(GetLastError()) + ")");
if(not input.close_read_handle())
MIGRAPHX_THROW("Error closing STDIN handle for reading (" +
std::to_string(GetLastError()) + ")");
f(input, output);
if(not input.close_write_handle())
MIGRAPHX_THROW("Error closing STDIN handle for writing (" +
std::to_string(GetLastError()) + ")");
WaitForSingleObject(process_info.hProcess, INFINITE); WaitForSingleObject(process_info.hProcess, INFINITE);
...@@ -196,24 +261,24 @@ int exec(const std::string& cmd, F f) ...@@ -196,24 +261,24 @@ int exec(const std::string& cmd, F f)
return static_cast<int>(status); return static_cast<int>(status);
} }
// cppcheck-suppress catchExceptionByValue // cppcheck-suppress catchExceptionByValue
catch(DWORD last_error) catch(DWORD error)
{ {
return last_error; MIGRAPHX_THROW("Error spawning process (" + std::to_string(error) + ")");
} }
} }
int exec(const std::string& cmd) int exec(const std::pair<std::string, std::string>& cmd)
{ {
TCHAR buffer[MIGRAPHX_PROCESS_BUFSIZE]; TCHAR buffer[MIGRAPHX_PROCESS_BUFSIZE];
HANDLE std_out{GetStdHandle(STD_OUTPUT_HANDLE)}; HANDLE std_out{GetStdHandle(STD_OUTPUT_HANDLE)};
return (std_out == nullptr or std_out == INVALID_HANDLE_VALUE) return (std_out == nullptr or std_out == INVALID_HANDLE_VALUE)
? GetLastError() ? GetLastError()
: exec(cmd, [&](const pipe&, const pipe& out) { : exec(cmd, [&](const pipe<direction::input>&, const pipe<direction::output>& out) {
for(;;) for(;;)
{ {
if(auto result = out.read(buffer, MIGRAPHX_PROCESS_BUFSIZE)) if(auto result = out.read(buffer, MIGRAPHX_PROCESS_BUFSIZE))
{ {
auto [more_data, bytes_read] = *result; auto& [more_data, bytes_read] = *result;
if(not more_data or bytes_read == 0) if(not more_data or bytes_read == 0)
break; break;
DWORD written; DWORD written;
...@@ -224,10 +289,11 @@ int exec(const std::string& cmd) ...@@ -224,10 +289,11 @@ int exec(const std::string& cmd)
}); });
} }
int exec(const std::string& cmd, std::function<void(process::writer)> std_in) int exec(const std::pair<std::string, std::string>& cmd,
std::function<void(process::writer)> std_in)
{ {
return exec(cmd, [&](const pipe& in, const pipe&) { return exec(cmd, [&](const pipe<direction::input>& input, const pipe<direction::output>&) {
std_in([&](const char* buffer, std::size_t n) { in.write(buffer, n); }); std_in([&](const char* buffer, std::size_t n) { input.write(buffer, n); });
}); });
} }
...@@ -238,6 +304,10 @@ struct process_impl ...@@ -238,6 +304,10 @@ struct process_impl
std::string command{}; std::string command{};
fs::path cwd{}; fs::path cwd{};
#ifdef _WIN32
std::pair<std::string, std::string> get_params() const { return {cwd.string(), command}; }
#endif
std::string get_command() const std::string get_command() const
{ {
std::string result; std::string result;
...@@ -283,13 +353,13 @@ void process::exec() ...@@ -283,13 +353,13 @@ void process::exec()
#ifndef _WIN32 #ifndef _WIN32
impl->check_exec(impl->get_command(), redirect_to(std::cout)); impl->check_exec(impl->get_command(), redirect_to(std::cout));
#else #else
impl->check_exec(impl->get_command()); impl->check_exec(impl->get_params());
#endif #endif
} }
void process::write(std::function<void(process::writer)> pipe_in) void process::write(std::function<void(process::writer)> pipe_in)
{ {
impl->check_exec(impl->get_command(), std::move(pipe_in)); impl->check_exec(impl->get_params(), std::move(pipe_in));
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -31,10 +31,31 @@ ...@@ -31,10 +31,31 @@
#include <iostream> #include <iostream>
#include <cstring> #include <cstring>
#ifdef _WIN32
// cppcheck-suppress definePrefix
#define WIN32_LEAN_AND_MEAN
#include <Windows.h>
#endif
std::vector<char> read_stdin() std::vector<char> read_stdin()
{ {
std::vector<char> result; std::vector<char> result;
#ifdef _WIN32
HANDLE std_in = GetStdHandle(STD_INPUT_HANDLE);
if(std_in == INVALID_HANDLE_VALUE)
MIGRAPHX_THROW("STDIN invalid handle (" + std::to_string(GetLastError()) + ")");
constexpr std::size_t BUFFER_SIZE = 1024;
DWORD bytes_read;
TCHAR buffer[BUFFER_SIZE];
for(;;)
{
BOOL status = ReadFile(std_in, buffer, BUFFER_SIZE, &bytes_read, nullptr);
if(status == FALSE or bytes_read == 0)
break;
result.insert(result.end(), buffer, buffer + bytes_read);
}
#else
std::array<char, 1024> buffer; std::array<char, 1024> buffer;
std::size_t len = 0; std::size_t len = 0;
while((len = std::fread(buffer.data(), 1, buffer.size(), stdin)) > 0) while((len = std::fread(buffer.data(), 1, buffer.size(), stdin)) > 0)
...@@ -44,6 +65,7 @@ std::vector<char> read_stdin() ...@@ -44,6 +65,7 @@ std::vector<char> read_stdin()
result.insert(result.end(), buffer.data(), buffer.data() + len); result.insert(result.end(), buffer.data(), buffer.data() + len);
} }
#endif
return result; return result;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment