Commit dd3a00f9 authored by Artur Wojcik's avatar Artur Wojcik
Browse files

fix process class on Windows and HIPRTC driver

parent 6aa6c954
......@@ -88,21 +88,38 @@ int exec(const std::string& cmd, std::function<void(process::writer)> std_in)
constexpr std::size_t MIGRAPHX_PROCESS_BUFSIZE = 4096;
enum class direction
{
input,
output
};
template <direction dir>
class pipe
{
public:
explicit pipe(bool inherit_handle = true)
explicit pipe()
{
SECURITY_ATTRIBUTES attrs;
attrs.nLength = sizeof(SECURITY_ATTRIBUTES);
attrs.bInheritHandle = inherit_handle ? TRUE : FALSE;
attrs.bInheritHandle = TRUE;
attrs.lpSecurityDescriptor = nullptr;
if(CreatePipe(&m_read, &m_write, &attrs, 0) == FALSE)
throw GetLastError();
if(SetHandleInformation(&m_read, HANDLE_FLAG_INHERIT, 0) == FALSE)
throw GetLastError();
if(dir == direction::output)
{
// Do not inherit the read handle for the output pipe
if(SetHandleInformation(m_read, HANDLE_FLAG_INHERIT, 0) == 0)
throw GetLastError();
}
else
{
// Do not inherit the write handle for the input pipe
if(SetHandleInformation(m_write, HANDLE_FLAG_INHERIT, 0) == 0)
throw GetLastError();
}
}
pipe(const pipe&) = delete;
......@@ -112,10 +129,36 @@ class pipe
~pipe()
{
CloseHandle(m_read);
m_read = nullptr;
CloseHandle(m_write);
m_write = nullptr;
if(m_write != nullptr)
{
CloseHandle(m_write);
}
if(m_read != nullptr)
{
CloseHandle(m_read);
}
}
bool close_write_handle()
{
auto result = true;
if(m_write != nullptr)
{
result = CloseHandle(m_write) == TRUE;
m_write = nullptr;
}
return result;
}
bool close_read_handle()
{
auto result = true;
if(m_read != nullptr)
{
result = CloseHandle(m_read) == TRUE;
m_read = nullptr;
}
return result;
}
std::optional<std::pair<bool, DWORD>> read(LPVOID buffer, DWORD length) const
......@@ -148,42 +191,64 @@ class pipe
};
template <typename F>
int exec(const std::string& cmd, F f)
int exec(const std::pair<std::string, std::string>& command, F f)
{
auto& [cwd, cmd] = command;
if((cmd.length() + 1) > MAX_PATH)
MIGRAPHX_THROW("Command too long, required maximum " + std::to_string(MAX_PATH) +
" characters (including terminating null character)");
try
{
if(enabled(MIGRAPHX_TRACE_CMD_EXECUTE{}))
std::cout << cmd << std::endl;
std::cout << "[cwd=" << cwd << "]; cmd='" << cmd << "'\n";
STARTUPINFO info;
PROCESS_INFORMATION process_info;
pipe in{}, out{};
pipe<direction::input> input{};
pipe<direction::output> output{};
ZeroMemory(&info, sizeof(STARTUPINFO));
info.cb = sizeof(STARTUPINFO);
info.hStdError = out.get_write_handle();
info.hStdOutput = out.get_write_handle();
info.hStdInput = in.get_read_handle();
info.hStdError = output.get_write_handle();
info.hStdOutput = output.get_write_handle();
info.hStdInput = input.get_read_handle();
info.dwFlags |= STARTF_USESTDHANDLES;
TCHAR cmdline[MAX_PATH];
std::copy(std::begin(cmd), std::end(cmd), std::begin(cmdline));
ZeroMemory(&process_info, sizeof(process_info));
if(CreateProcess(nullptr,
const_cast<LPSTR>(cmd.c_str()),
cmdline,
nullptr,
nullptr,
TRUE,
0,
nullptr,
nullptr,
cwd.empty() ? nullptr : static_cast<LPCSTR>(cwd.c_str()),
&info,
&process_info) == FALSE)
{
return GetLastError();
MIGRAPHX_THROW("Error creating process (" + std::to_string(GetLastError()) + ")");
}
f(in, out);
if(not output.close_write_handle())
MIGRAPHX_THROW("Error closing STDOUT handle for writing (" +
std::to_string(GetLastError()) + ")");
if(not input.close_read_handle())
MIGRAPHX_THROW("Error closing STDIN handle for reading (" +
std::to_string(GetLastError()) + ")");
f(input, output);
if(not input.close_write_handle())
MIGRAPHX_THROW("Error closing STDIN handle for writing (" +
std::to_string(GetLastError()) + ")");
WaitForSingleObject(process_info.hProcess, INFINITE);
......@@ -202,18 +267,18 @@ int exec(const std::string& cmd, F f)
}
}
int exec(const std::string& cmd)
int exec(const std::pair<std::string, std::string>& cmd)
{
TCHAR buffer[MIGRAPHX_PROCESS_BUFSIZE];
HANDLE std_out{GetStdHandle(STD_OUTPUT_HANDLE)};
return (std_out == nullptr or std_out == INVALID_HANDLE_VALUE)
? GetLastError()
: exec(cmd, [&](const pipe&, const pipe& out) {
: exec(cmd, [&](const pipe<direction::input>&, const pipe<direction::output>& out) {
for(;;)
{
if(auto result = out.read(buffer, MIGRAPHX_PROCESS_BUFSIZE))
{
auto [more_data, bytes_read] = *result;
auto& [more_data, bytes_read] = *result;
if(not more_data or bytes_read == 0)
break;
DWORD written;
......@@ -224,10 +289,11 @@ int exec(const std::string& cmd)
});
}
int exec(const std::string& cmd, std::function<void(process::writer)> std_in)
int exec(const std::pair<std::string, std::string>& cmd,
std::function<void(process::writer)> std_in)
{
return exec(cmd, [&](const pipe& in, const pipe&) {
std_in([&](const char* buffer, std::size_t n) { in.write(buffer, n); });
return exec(cmd, [&](const pipe<direction::input>& input, const pipe<direction::output>&) {
std_in([&](const char* buffer, std::size_t n) { input.write(buffer, n); });
});
}
......@@ -238,6 +304,10 @@ struct process_impl
std::string command{};
fs::path cwd{};
#ifdef _WIN32
std::pair<std::string, std::string> get_params() const { return {cwd.string(), command}; }
#endif
std::string get_command() const
{
std::string result;
......@@ -283,13 +353,13 @@ void process::exec()
#ifndef _WIN32
impl->check_exec(impl->get_command(), redirect_to(std::cout));
#else
impl->check_exec(impl->get_command());
impl->check_exec(impl->get_params());
#endif
}
void process::write(std::function<void(process::writer)> pipe_in)
{
impl->check_exec(impl->get_command(), std::move(pipe_in));
impl->check_exec(impl->get_params(), std::move(pipe_in));
}
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -31,10 +31,31 @@
#include <iostream>
#include <cstring>
#ifdef _WIN32
// cppcheck-suppress definePrefix
#define WIN32_LEAN_AND_MEAN
#include <Windows.h>
#endif
std::vector<char> read_stdin()
{
std::vector<char> result;
#ifdef _WIN32
HANDLE std_in = GetStdHandle(STD_INPUT_HANDLE);
if(std_in == INVALID_HANDLE_VALUE)
MIGRAPHX_THROW("STDIN invalid handle (" + std::to_string(GetLastError()) + ")");
constexpr std::size_t BUFFER_SIZE = 1024;
DWORD bytes_read;
TCHAR buffer[BUFFER_SIZE];
for(;;)
{
BOOL status = ReadFile(std_in, buffer, BUFFER_SIZE, &bytes_read, nullptr);
if(status == FALSE or bytes_read == 0)
break;
result.insert(result.end(), buffer, buffer + bytes_read);
}
#else
std::array<char, 1024> buffer;
std::size_t len = 0;
while((len = std::fread(buffer.data(), 1, buffer.size(), stdin)) > 0)
......@@ -44,6 +65,7 @@ std::vector<char> read_stdin()
result.insert(result.end(), buffer.data(), buffer.data() + len);
}
#endif
return result;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment