Commit 27232e7b authored by sxtyzhangzk, committed by Zhekai Zhang

[major] add setDevice & load weights from pytorch

parent 0b1891cd
@@ -10,10 +10,13 @@
 class QuantizedFluxModel : public ModuleWrapper<FluxModel> { // : public torch::CustomClassHolder {
 public:
     void init(bool use_fp4, bool offload, bool bf16, int8_t deviceId) {
-        spdlog::info("Initializing QuantizedFluxModel");
+        spdlog::info("Initializing QuantizedFluxModel on device {}", deviceId);
         if (offload) {
             spdlog::info("Layer offloading enabled");
         }
+        ModuleWrapper::init(deviceId);
+        CUDADeviceContext ctx(this->deviceId);
         net = std::make_unique<FluxModel>(use_fp4, offload, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
     }
@@ -27,6 +30,7 @@ public:
         bool skip_first_layer = false)
     {
         checkModel();
+        CUDADeviceContext ctx(deviceId);
         spdlog::debug("QuantizedFluxModel forward");
@@ -61,6 +65,8 @@ public:
         torch::Tensor rotary_emb_img,
         torch::Tensor rotary_emb_context)
     {
+        CUDADeviceContext ctx(deviceId);
         spdlog::debug("QuantizedFluxModel forward_layer {}", idx);
         hidden_states = hidden_states.contiguous();
@@ -91,6 +97,8 @@ public:
         torch::Tensor temb,
         torch::Tensor rotary_emb_single)
     {
+        CUDADeviceContext ctx(deviceId);
         spdlog::debug("QuantizedFluxModel forward_single_layer {}", idx);
         hidden_states = hidden_states.contiguous();
@@ -117,6 +125,8 @@ public:
             throw std::invalid_argument("skipRanks must be multiples of 16");
         }
+        CUDADeviceContext ctx(deviceId);
         spdlog::info("Set lora scale to {} (skip {} ranks)", scale, skipRanks);
         net->traverse([&](Module *module) {
...
@@ -9,7 +9,12 @@
 template<typename M>
 class ModuleWrapper {
 public:
+    void init(int deviceId) {
+        this->deviceId = deviceId;
+    }
     void reset() {
+        CUDADeviceContext ctx(this->deviceId);
         debugContext.reset();
         net.reset();
         Tensor::synchronizeDevice();
@@ -20,6 +25,7 @@ public:
     void load(std::string path, bool partial = false) {
         checkModel();
+        CUDADeviceContext ctx(this->deviceId);
         spdlog::info("{} weights from {}", partial ? "Loading partial" : "Loading", path);
@@ -30,6 +36,19 @@ public:
         spdlog::info("Done.");
     }
+    void loadDict(std::map<std::string, torch::Tensor> dict, bool partial = false) {
+        checkModel();
+        CUDADeviceContext ctx(this->deviceId);
+        spdlog::info("{} weights from pytorch", partial ? "Loading partial" : "Loading");
+        std::shared_ptr<TensorsProviderTorch> provider = std::make_shared<TensorsProviderTorch>(std::move(dict));
+        net->loadParams(*provider, partial);
+        Tensor::synchronizeDevice();
+        spdlog::info("Done.");
+    }
     void startDebug() {
         debugContext = std::make_unique<DebugContext>();
     }
@@ -38,6 +57,8 @@ public:
     }
     auto getDebugResults() {
+        CUDADeviceContext ctx(this->deviceId);
         std::map<std::string, torch::Tensor> result;
         if (debugContext) {
@@ -59,4 +80,6 @@ protected:
 protected:
     std::unique_ptr<M> net;
     std::unique_ptr<DebugContext> debugContext;
+    int deviceId = -1;
 };
\ No newline at end of file
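
For orientation, a minimal sketch (not part of the commit) of how the new loadDict path can be driven from C++; the pybind11 binding below exposes the same call to Python. The include, the parameter key, and the tensor shape are placeholders, and partial = true presumably lets loadParams() skip parameters that are missing from the dict rather than treat them as an error.

    #include <map>
    #include <string>
    #include <torch/torch.h>
    // #include "..."  // project header providing QuantizedFluxModel (path is project-specific)

    void loadFromStateDict(QuantizedFluxModel &model) {
        // Keys must match the names net->loadParams() expects; this key is illustrative only.
        std::map<std::string, torch::Tensor> dict;
        dict["transformer_blocks.0.attn.qkv_proj.weight"] =
            torch::zeros({9216, 3072}, torch::kBFloat16);

        // Hand the PyTorch tensors over; loadDict wraps them in a TensorsProviderTorch
        // and forwards them to loadParams under the model's CUDADeviceContext.
        model.loadDict(std::move(dict), /*partial=*/true);
    }
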
@@ -22,6 +22,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
             py::arg("path"),
             py::arg("partial") = false
         )
+        .def("loadDict", &QuantizedFluxModel::loadDict,
+            py::arg("dict"),
+            py::arg("partial") = false
+        )
         .def("forward", &QuantizedFluxModel::forward)
         .def("forward_layer", &QuantizedFluxModel::forward_layer)
         .def("forward_single_layer", &QuantizedFluxModel::forward_single_layer)
@@ -45,6 +49,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
             py::arg("path"),
             py::arg("partial") = false
         )
+        .def("loadDict", &QuantizedSanaModel::loadDict,
+            py::arg("dict"),
+            py::arg("partial") = false
+        )
         .def("forward", &QuantizedSanaModel::forward)
         .def("forward_layer", &QuantizedSanaModel::forward_layer)
         .def("startDebug", &QuantizedSanaModel::startDebug)
...
@@ -9,7 +9,7 @@
 class QuantizedSanaModel : public ModuleWrapper<SanaModel> {
 public:
     void init(pybind11::dict config, std::vector<int> pag_layers, bool use_fp4, bool bf16, int8_t deviceId) {
-        spdlog::info("Initializing QuantizedSanaModel");
+        spdlog::info("Initializing QuantizedSanaModel on device {}", deviceId);
         SanaConfig cfg{
             .num_layers = config["num_layers"].cast<int>(),
             .num_attention_heads = config["num_attention_heads"].cast<int>(),
@@ -19,6 +19,9 @@ public:
             .pag_layers = pag_layers,
             .use_fp4 = use_fp4,
         };
+        ModuleWrapper::init(deviceId);
+        CUDADeviceContext ctx(this->deviceId);
         net = std::make_unique<SanaModel>(cfg, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
     }
@@ -34,6 +37,7 @@ public:
         bool cfg)
     {
         checkModel();
+        CUDADeviceContext ctx(deviceId);
         spdlog::debug("QuantizedSanaModel forward");
@@ -72,6 +76,7 @@ public:
         bool cfg)
     {
         checkModel();
+        CUDADeviceContext ctx(deviceId);
         spdlog::debug("QuantizedSanaModel forward_layer {}", idx);
...
@@ -81,7 +81,8 @@ public:
     BufferCUDA(size_t size) {
         this->size = size;
         this->device.type = Device::CUDA;
-        checkCUDA(cudaGetDevice(&this->device.idx));
+        // checkCUDA(cudaGetDevice(&this->device.idx));
+        this->device.idx = CUDADeviceContext::getDevice();
         if (size == 0) {
             this->ptr = nullptr;
         }
@@ -418,6 +419,7 @@ public:
            result.buffer = std::make_shared<BufferMalloc>(shape.size() * scalarSize.at(scalarType));
        } else if (device.type == Device::CUDA) {
            // TODO: cross device allocate
+           CUDADeviceContext ctx(device.idx);
            result.buffer = std::make_shared<BufferCUDA>(shape.size() * scalarSize.at(scalarType));
        } else {
            assert(false);
@@ -429,6 +431,7 @@ public:
        if (device.type == Device::CPU) {
            memset(result.buffer->getPtr(), 0xCC, result.buffer->getSize());
        } else if (device.type == Device::CUDA) {
+           CUDADeviceContext ctx(device.idx);
            checkCUDA(cudaMemsetAsync(result.buffer->getPtr(), 0xCC, result.buffer->getSize(), getCurrentCUDAStream()));
        }
    }
...
@@ -107,16 +107,97 @@ struct CUDAEventWrapper {
     }
 };
+/**
+ * 1. hold one when entering from external code (leave `device` at -1 to avoid a device change)
+ * 2. hold one when switching device
+ * 3. hold one with `disableCache` when calling external code that may change the device
+ */
+class CUDADeviceContext {
+public:
+    CUDADeviceContext(int device = -1, bool disableCache = false) : disableCache(disableCache) {
+        if (cacheDisabled()) {
+            // no previous context => we might have entered from external code, reset cache
+            // previous context has disableCache on => external code may have run, reset
+            currentDeviceCache = -1;
+        }
+        ctxs.push(this);
+        lastDevice = getDevice();
+        if (device >= 0) {
+            setDevice(device);
+        }
+        if (disableCache) {
+            // we are about to call external code, reset cache
+            currentDeviceCache = -1;
+        }
+    }
+    CUDADeviceContext(const CUDADeviceContext &) = delete;
+    CUDADeviceContext(CUDADeviceContext &&) = delete;
+    ~CUDADeviceContext() {
+        if (disableCache) {
+            // returned from external code, cache is not reliable, reset
+            currentDeviceCache = -1;
+        }
+        setDevice(lastDevice);
+        assert(ctxs.top() == this);
+        ctxs.pop();
+        if (cacheDisabled()) {
+            // ctxs.empty() => we are about to return to external code, reset cache
+            // otherwise => we are nested inside a context with disableCache on; external code may run next, reset
+            currentDeviceCache = -1;
+        }
+    }
+    const bool disableCache;
+    int lastDevice;
+public:
+    static int getDevice() {
+        int idx = -1;
+        if (cacheDisabled() || currentDeviceCache < 0) {
+            checkCUDA(cudaGetDevice(&idx));
+        } else {
+            idx = currentDeviceCache;
+        }
+        currentDeviceCache = cacheDisabled() ? -1 : idx;
+        return idx;
+    }
+private:
+    static void setDevice(int idx) {
+        // TODO: deal with stream when switching device
+        assert(idx >= 0);
+        if (!cacheDisabled() && currentDeviceCache == idx) {
+            return;
+        }
+        checkCUDA(cudaSetDevice(idx));
+        currentDeviceCache = cacheDisabled() ? -1 : idx;
+    }
+private:
+    static inline thread_local std::stack<CUDADeviceContext *> ctxs;
+    static inline thread_local int currentDeviceCache = -1;
+    static bool cacheDisabled() {
+        return ctxs.empty() || ctxs.top()->disableCache;
+    }
+};
 inline cudaDeviceProp *getCurrentDeviceProperties() {
-    static thread_local cudaDeviceProp prop;
-    static thread_local bool propAvailable = false;
-    if (!propAvailable) {
-        int device;
-        checkCUDA(cudaGetDevice(&device));
-        checkCUDA(cudaGetDeviceProperties(&prop, device));
-        propAvailable = true;
+    static thread_local std::map<int, cudaDeviceProp> props;
+    int deviceId = CUDADeviceContext::getDevice();
+    if (!props.contains(deviceId)) {
+        cudaDeviceProp prop;
+        checkCUDA(cudaGetDeviceProperties(&prop, deviceId));
+        props[deviceId] = prop;
     }
-    return &prop;
+    return &props.at(deviceId);
 }
 template<typename T>
...
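
To make the three usage patterns listed in the CUDADeviceContext comment block concrete, here is a short sketch (not part of the commit); the wrapper function names are invented for illustration.

    // Pattern 1: entered from external code (e.g. a pybind11-exported method).
    // device stays at -1, so no device switch happens; the guard anchors the
    // thread-local cache and restores the caller's device on destruction.
    void exportedEntryPoint() {
        CUDADeviceContext guard;
        // ... launch kernels on whatever device the caller selected ...
    }

    // Pattern 2: switching device. The constructor issues cudaSetDevice(deviceId)
    // (skipped if the cache already matches) and the destructor restores lastDevice.
    void runOnDevice(int deviceId) {
        CUDADeviceContext guard(deviceId);
        // ... allocate buffers and run the model on deviceId ...
    }

    // Pattern 3: calling external code (e.g. torch ops) that may itself change the
    // current device. disableCache forces real cudaGetDevice/cudaSetDevice calls
    // instead of trusting the thread-local cache.
    void callIntoExternalLibrary() {
        CUDADeviceContext guard(-1, /*disableCache=*/true);
        // ... call code outside this project's control ...
    }
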
@@ -22,6 +22,7 @@ Tensor from_torch(at::Tensor input) {
     }
     static const std::map<at::ScalarType, Tensor::ScalarType> mapType = {
+        { at::ScalarType::Char, Tensor::INT8 },
         { at::ScalarType::Byte, Tensor::INT8 },
         { at::ScalarType::Int, Tensor::INT32 },
         { at::ScalarType::Long, Tensor::INT64 },
@@ -36,7 +37,7 @@ Tensor from_torch(at::Tensor input) {
     result.scalarType = mapType.at(input.scalar_type());
     result.buffer = std::make_shared<BufferTorchTensor>(std::move(input));
-    // Tensor::lockBuffer(result.buffer, getCurrentCUDAStream());
+    Tensor::lockBuffer(result.buffer, getCurrentCUDAStream());
     return result;
 }
...
@@ -15,7 +15,7 @@ public:
     }
     virtual bool isAsyncBuffer() override {
         // TODO: figure out how torch manages memory
-        return true;
+        return this->device.type == Device::CUDA;
     }
 private:
     at::Tensor tensor;
@@ -30,4 +30,22 @@ public:
 };
 Tensor from_torch(at::Tensor input);
 at::Tensor to_torch(Tensor input);
\ No newline at end of file
+class TensorsProviderTorch : public TensorsProvider {
+public:
+    TensorsProviderTorch(std::map<std::string, at::Tensor> dict) : storage(std::move(dict)) {}
+    virtual bool contains(const std::string &key) const override {
+        return storage.contains(key);
+    }
+    virtual Tensor getTensor(const std::string &key) override {
+        if (!storage.contains(key)) {
+            return Tensor{};
+        }
+        return from_torch(storage.at(key));
+    }
+private:
+    std::map<std::string, at::Tensor> storage;
+};
\ No newline at end of file
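
A rough sketch (not part of the commit) of the provider contract TensorsProviderTorch fulfills for loadParams(): contains() lets a loader probe for optional weights, and getTensor() returns an empty Tensor for unknown keys, which is presumably what partial loading relies on. The key below is made up.

    void loadOptionalBias(TensorsProviderTorch &provider) {
        // Optional parameter: probe before converting (relevant for partial loads).
        if (provider.contains("final_layer.proj_out.bias")) {
            Tensor bias = provider.getTensor("final_layer.proj_out.bias");
            // ... copy `bias` into the corresponding module parameter ...
        }
    }
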