Unverified Commit 1e621a58 authored by Hu Yaoqi, committed by GitHub

fix: resolve issue with inability to correctly specify non-zero GPUs in multi-GPU systems (#404)

* Fix: Correctly specify non-zero GPUs in multi-GPU environments

This commit resolves an issue where the Nunchaku model could not be
correctly initialized and run on a user-specified non-zero GPU in
multi-GPU systems.

Key changes include:
- Using CUDADeviceContext in the FluxModel constructor to ensure
  the model and its submodules are created within the specified GPU context.
- Modifying the logic in FluxModel::forward that copies residual data
  from the CPU back to the GPU, so the data returns to the original GPU device.
- Adding explicit CUDA context management in Tensor::copy_ for data
  copy operations involving CUDA devices (H2D, D2H, D2D) to guarantee
  cudaMemcpyAsync executes on the correct device.

These changes allow users to reliably run Nunchaku on any specified
GPU in a multi-GPU setup.

* finish pre-commit
parent 3eabbd06
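
For illustration, here is a minimal sketch of how the fixed constructor path could be exercised on a non-zero GPU. The Device field names (type, idx) and the FluxModel signature are taken from the diffs below; the ScalarType value and the field-by-field Device initialization are assumptions made only for this example.

// Hedged sketch: run the model on GPU 1 instead of the default GPU 0.
Device device;
device.type = Device::CUDA; // enum member as used in Tensor::copy_ below
device.idx  = 1;            // non-zero GPU that previously failed to initialize correctly

// Constructor signature as shown in the first hunk below:
// FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device)
FluxModel model(/*use_fp4=*/false, /*offload=*/false, Tensor::ScalarType::FP16, device); // FP16 name assumed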
@@ -778,6 +778,8 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
 FluxModel::FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device)
     : dtype(dtype), offload(offload) {
+    CUDADeviceContext model_construction_ctx(device.idx);
     for (int i = 0; i < 19; i++) {
         transformer_blocks.push_back(
             std::make_unique<JointTransformerBlock>(3072, 24, 3072, false, use_fp4, dtype, device));
...
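
The CUDADeviceContext guard added above is what scopes the construction to device.idx. Its exact implementation is not part of this diff; the following is a typical RAII pattern with the behavior the commit message describes (switch to the target GPU on construction, restore the previous GPU on destruction), using only standard CUDA runtime calls. The class name and details here are assumptions.

#include <cuda_runtime.h>

// Hypothetical guard with CUDADeviceContext-like behavior (name and details assumed).
class ScopedCudaDevice {
public:
    explicit ScopedCudaDevice(int target) {
        cudaGetDevice(&previous_);      // remember the caller's current device
        if (target != previous_)
            cudaSetDevice(target);      // subsequent allocations, streams, and kernels target `target`
    }
    ~ScopedCudaDevice() {
        cudaSetDevice(previous_);       // restore the original device on scope exit
    }
    ScopedCudaDevice(const ScopedCudaDevice &) = delete;
    ScopedCudaDevice &operator=(const ScopedCudaDevice &) = delete;

private:
    int previous_ = 0;
};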
@@ -432,6 +432,13 @@ public:
             return *this;
         }
+
+        std::optional<CUDADeviceContext> operation_ctx_guard;
+        if (this->device().type == Device::CUDA) {
+            operation_ctx_guard.emplace(this->device().idx);
+        } else if (other.device().type == Device::CUDA) {
+            operation_ctx_guard.emplace(other.device().idx);
+        }
         if (this->device().type == Device::CPU && other.device().type == Device::CPU) {
             memcpy(data_ptr<char>(), other.data_ptr<char>(), shape.size() * scalar_size());
             return *this;
...
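
To see why the guard matters for copy_, note that cudaMemcpyAsync is issued on a stream, and a stream is associated with the device that was current when it was created. Below is a minimal, self-contained sketch of that behavior; the device index, buffer sizes, and names are illustrative only and are not taken from the nunchaku code.

#include <cuda_runtime.h>
#include <vector>

int main() {
    const size_t n = 1 << 20;
    std::vector<float> host(n, 1.0f);

    // Make GPU 1 (a non-zero device, as in the bug report) current before
    // allocating, creating the stream, and launching the async copy.
    cudaSetDevice(1);

    float *dev1 = nullptr;
    cudaMalloc(&dev1, n * sizeof(float));

    cudaStream_t stream;
    cudaStreamCreate(&stream); // stream is tied to the current device (GPU 1)
    cudaMemcpyAsync(dev1, host.data(), n * sizeof(float),
                    cudaMemcpyHostToDevice, stream);
    cudaStreamSynchronize(stream);

    cudaStreamDestroy(stream);
    cudaFree(dev1);
    return 0;
}

If the copy were issued while GPU 0 was still current, the stream and the transfer would be associated with the wrong device, which is the kind of failure this commit addresses.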