Commit bfe935a9 authored by Artur Wojcik's avatar Artur Wojcik
Browse files

fix process class on Windows and HIPRTC driver

parent 855e099b
...@@ -94,7 +94,7 @@ enum class direction ...@@ -94,7 +94,7 @@ enum class direction
output output
}; };
template <direction dir, bool inherit_handle = true> template <direction dir>
class pipe class pipe
{ {
public: public:
...@@ -102,26 +102,23 @@ class pipe ...@@ -102,26 +102,23 @@ class pipe
{ {
SECURITY_ATTRIBUTES attrs; SECURITY_ATTRIBUTES attrs;
attrs.nLength = sizeof(SECURITY_ATTRIBUTES); attrs.nLength = sizeof(SECURITY_ATTRIBUTES);
attrs.bInheritHandle = inherit_handle ? TRUE : FALSE; attrs.bInheritHandle = TRUE;
attrs.lpSecurityDescriptor = nullptr; attrs.lpSecurityDescriptor = nullptr;
if(CreatePipe(&m_read, &m_write, &attrs, 0) == FALSE) if(CreatePipe(&m_read, &m_write, &attrs, 0) == FALSE)
throw GetLastError(); throw GetLastError();
if(inherit_handle) if(dir == direction::output)
{ {
if(dir == direction::output) // Do not inherit the read handle for the output pipe
{ if(SetHandleInformation(m_read, HANDLE_FLAG_INHERIT, 0) == 0)
// Do not inherit the read handle for the output pipe throw GetLastError();
if(SetHandleInformation(m_read, HANDLE_FLAG_INHERIT, 0) == 0) }
throw GetLastError(); else
} {
else // Do not inherit the write handle for the input pipe
{ if(SetHandleInformation(m_write, HANDLE_FLAG_INHERIT, 0) == 0)
// Do not inherit the write handle for the input pipe throw GetLastError();
if(SetHandleInformation(m_write, HANDLE_FLAG_INHERIT, 0) == 0)
throw GetLastError();
}
} }
} }
...@@ -132,31 +129,36 @@ class pipe ...@@ -132,31 +129,36 @@ class pipe
~pipe() ~pipe()
{ {
if(m_read != nullptr)
CloseHandle(m_read);
if(m_write != nullptr) if(m_write != nullptr)
{
CloseHandle(m_write); CloseHandle(m_write);
} }
void close_read_handle()
{
if(m_read != nullptr) if(m_read != nullptr)
{ {
if(CloseHandle(m_read) == 0) CloseHandle(m_read);
MIGRAPHX_THROW("Error closing read handle: " + std::to_string(GetLastError()));
m_read = nullptr;
} }
} }
void close_write_handle() bool close_write_handle()
{ {
auto result = true;
if(m_write != nullptr) if(m_write != nullptr)
{ {
if(CloseHandle(m_write) == 0) result = CloseHandle(m_write) == TRUE;
MIGRAPHX_THROW("Error closing write handle: " + std::to_string(GetLastError()));
m_write = nullptr; m_write = nullptr;
} }
return result;
}
bool close_read_handle()
{
auto result = true;
if(m_read != nullptr)
{
result = CloseHandle(m_read) == TRUE;
m_read = nullptr;
}
return result;
} }
std::optional<std::pair<bool, DWORD>> read(LPVOID buffer, DWORD length) const std::optional<std::pair<bool, DWORD>> read(LPVOID buffer, DWORD length) const
...@@ -189,12 +191,18 @@ class pipe ...@@ -189,12 +191,18 @@ class pipe
}; };
template <typename F> template <typename F>
int exec(const std::string& cmd, F f) int exec(const std::pair<std::string, std::string>& command, F f)
{ {
auto& [cwd, cmd] = command;
if((cmd.length() + 1) > MAX_PATH)
MIGRAPHX_THROW("Command too long, required maximum " + std::to_string(MAX_PATH) +
" characters (including terminating null character)");
try try
{ {
if(enabled(MIGRAPHX_TRACE_CMD_EXECUTE{})) if(enabled(MIGRAPHX_TRACE_CMD_EXECUTE{}))
std::cout << cmd << std::endl; std::cout << "[cwd=" << cwd << "]; cmd='" << cmd << "'\n";
STARTUPINFO info; STARTUPINFO info;
PROCESS_INFORMATION process_info; PROCESS_INFORMATION process_info;
...@@ -209,24 +217,39 @@ int exec(const std::string& cmd, F f) ...@@ -209,24 +217,39 @@ int exec(const std::string& cmd, F f)
info.hStdInput = input.get_read_handle(); info.hStdInput = input.get_read_handle();
info.dwFlags |= STARTF_USESTDHANDLES; info.dwFlags |= STARTF_USESTDHANDLES;
TCHAR cmdline[MAX_PATH];
std::copy(std::begin(cmd), std::end(cmd), std::begin(cmdline));
ZeroMemory(&process_info, sizeof(process_info)); ZeroMemory(&process_info, sizeof(process_info));
if(CreateProcess(nullptr, if(CreateProcess(nullptr,
const_cast<LPSTR>(cmd.c_str()), cmdline,
nullptr, nullptr,
nullptr, nullptr,
TRUE, TRUE,
0, 0,
nullptr, nullptr,
nullptr, cwd.empty() ? nullptr : static_cast<LPCSTR>(cwd.c_str()),
&info, &info,
&process_info) == FALSE) &process_info) == FALSE)
{ {
return GetLastError(); MIGRAPHX_THROW("Error creating process (" + std::to_string(GetLastError()) + ")");
} }
if(not output.close_write_handle())
MIGRAPHX_THROW("Error closing STDOUT handle for writing (" +
std::to_string(GetLastError()) + ")");
if(not input.close_read_handle())
MIGRAPHX_THROW("Error closing STDIN handle for reading (" +
std::to_string(GetLastError()) + ")");
f(input, output); f(input, output);
if(not input.close_write_handle())
MIGRAPHX_THROW("Error closing STDIN handle for writing (" +
std::to_string(GetLastError()) + ")");
WaitForSingleObject(process_info.hProcess, INFINITE); WaitForSingleObject(process_info.hProcess, INFINITE);
DWORD status{}; DWORD status{};
...@@ -235,9 +258,6 @@ int exec(const std::string& cmd, F f) ...@@ -235,9 +258,6 @@ int exec(const std::string& cmd, F f)
CloseHandle(process_info.hProcess); CloseHandle(process_info.hProcess);
CloseHandle(process_info.hThread); CloseHandle(process_info.hThread);
input.close_read_handle();
output.close_write_handle();
return static_cast<int>(status); return static_cast<int>(status);
} }
// cppcheck-suppress catchExceptionByValue // cppcheck-suppress catchExceptionByValue
...@@ -247,7 +267,7 @@ int exec(const std::string& cmd, F f) ...@@ -247,7 +267,7 @@ int exec(const std::string& cmd, F f)
} }
} }
int exec(const std::string& cmd) int exec(const std::pair<std::string, std::string>& cmd)
{ {
TCHAR buffer[MIGRAPHX_PROCESS_BUFSIZE]; TCHAR buffer[MIGRAPHX_PROCESS_BUFSIZE];
HANDLE std_out{GetStdHandle(STD_OUTPUT_HANDLE)}; HANDLE std_out{GetStdHandle(STD_OUTPUT_HANDLE)};
...@@ -258,7 +278,7 @@ int exec(const std::string& cmd) ...@@ -258,7 +278,7 @@ int exec(const std::string& cmd)
{ {
if(auto result = out.read(buffer, MIGRAPHX_PROCESS_BUFSIZE)) if(auto result = out.read(buffer, MIGRAPHX_PROCESS_BUFSIZE))
{ {
auto [more_data, bytes_read] = *result; auto& [more_data, bytes_read] = *result;
if(not more_data or bytes_read == 0) if(not more_data or bytes_read == 0)
break; break;
DWORD written; DWORD written;
...@@ -269,10 +289,11 @@ int exec(const std::string& cmd) ...@@ -269,10 +289,11 @@ int exec(const std::string& cmd)
}); });
} }
int exec(const std::string& cmd, std::function<void(process::writer)> std_in) int exec(const std::pair<std::string, std::string>& cmd,
std::function<void(process::writer)> std_in)
{ {
return exec(cmd, [&](const pipe<direction::input>& in, const pipe<direction::output>&) { return exec(cmd, [&](const pipe<direction::input>& input, const pipe<direction::output>&) {
std_in([&](const char* buffer, std::size_t n) { in.write(buffer, n); }); std_in([&](const char* buffer, std::size_t n) { input.write(buffer, n); });
}); });
} }
...@@ -283,6 +304,10 @@ struct process_impl ...@@ -283,6 +304,10 @@ struct process_impl
std::string command{}; std::string command{};
fs::path cwd{}; fs::path cwd{};
#ifdef _WIN32
std::pair<std::string, std::string> get_params() const { return {cwd.string(), command}; }
#endif
std::string get_command() const std::string get_command() const
{ {
std::string result; std::string result;
...@@ -328,13 +353,13 @@ void process::exec() ...@@ -328,13 +353,13 @@ void process::exec()
#ifndef _WIN32 #ifndef _WIN32
impl->check_exec(impl->get_command(), redirect_to(std::cout)); impl->check_exec(impl->get_command(), redirect_to(std::cout));
#else #else
impl->check_exec(impl->get_command()); impl->check_exec(impl->get_params());
#endif #endif
} }
void process::write(std::function<void(process::writer)> pipe_in) void process::write(std::function<void(process::writer)> pipe_in)
{ {
impl->check_exec(impl->get_command(), std::move(pipe_in)); impl->check_exec(impl->get_params(), std::move(pipe_in));
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -31,10 +31,31 @@ ...@@ -31,10 +31,31 @@
#include <iostream> #include <iostream>
#include <cstring> #include <cstring>
#ifdef _WIN32
// cppcheck-suppress definePrefix
#define WIN32_LEAN_AND_MEAN
#include <Windows.h>
#endif
std::vector<char> read_stdin() std::vector<char> read_stdin()
{ {
std::vector<char> result; std::vector<char> result;
#ifdef _WIN32
HANDLE std_in = GetStdHandle(STD_INPUT_HANDLE);
if(std_in == INVALID_HANDLE_VALUE)
MIGRAPHX_THROW("STDIN invalid handle (" + std::to_string(GetLastError()) + ")");
constexpr std::size_t BUFFER_SIZE = 1024;
DWORD bytes_read;
TCHAR buffer[BUFFER_SIZE];
for(;;)
{
BOOL status = ReadFile(std_in, buffer, BUFFER_SIZE, &bytes_read, nullptr);
if(status == FALSE or bytes_read == 0)
break;
result.insert(result.end(), buffer, buffer + bytes_read);
}
#else
std::array<char, 1024> buffer; std::array<char, 1024> buffer;
std::size_t len = 0; std::size_t len = 0;
while((len = std::fread(buffer.data(), 1, buffer.size(), stdin)) > 0) while((len = std::fread(buffer.data(), 1, buffer.size(), stdin)) > 0)
...@@ -44,6 +65,7 @@ std::vector<char> read_stdin() ...@@ -44,6 +65,7 @@ std::vector<char> read_stdin()
result.insert(result.end(), buffer.data(), buffer.data() + len); result.insert(result.end(), buffer.data(), buffer.data() + len);
} }
#endif
return result; return result;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment