"vscode:/vscode.git/clone" did not exist on "79892d376c5faeaa55f6ccb656b1a09c0781815a"
Commit 27232e7b authored by sxtyzhangzk, committed by Zhekai Zhang

[major] add setDevice & load weights from pytorch

parent 0b1891cd
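For orientation, a minimal caller-side sketch of what this commit enables; the wrapper header include, the device id, and the contents of `state_dict` are assumptions for illustration, not part of the diff:

```cpp
// Caller-side sketch only: assumes the QuantizedFluxModel header from this
// repo is included and that `state_dict` holds the model weights as torch
// tensors (e.g. produced on the Python side and passed through pybind11).
#include <map>
#include <string>
#include <utility>
#include <torch/torch.h>

void setup_example(QuantizedFluxModel &model,
                   std::map<std::string, torch::Tensor> state_dict) {
    // New: init() now records the target CUDA device via ModuleWrapper::init()
    // and constructs the network inside a CUDADeviceContext for that device.
    model.init(/*use_fp4=*/false, /*offload=*/false, /*bf16=*/true, /*deviceId=*/0);

    // New: loadDict() loads weights directly from torch tensors
    // (wrapped by TensorsProviderTorch) instead of a file path via load().
    model.loadDict(std::move(state_dict), /*partial=*/false);
}
```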
@@ -10,10 +10,13 @@
class QuantizedFluxModel : public ModuleWrapper<FluxModel> { // : public torch::CustomClassHolder {
public:
void init(bool use_fp4, bool offload, bool bf16, int8_t deviceId) {
spdlog::info("Initializing QuantizedFluxModel");
spdlog::info("Initializing QuantizedFluxModel on device {}", deviceId);
if (offload) {
spdlog::info("Layer offloading enabled");
}
ModuleWrapper::init(deviceId);
CUDADeviceContext ctx(this->deviceId);
net = std::make_unique<FluxModel>(use_fp4, offload, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
}
@@ -27,6 +30,7 @@ public:
bool skip_first_layer = false)
{
checkModel();
CUDADeviceContext ctx(deviceId);
spdlog::debug("QuantizedFluxModel forward");
@@ -61,6 +65,8 @@ public:
torch::Tensor rotary_emb_img,
torch::Tensor rotary_emb_context)
{
CUDADeviceContext ctx(deviceId);
spdlog::debug("QuantizedFluxModel forward_layer {}", idx);
hidden_states = hidden_states.contiguous();
@@ -91,6 +97,8 @@ public:
torch::Tensor temb,
torch::Tensor rotary_emb_single)
{
CUDADeviceContext ctx(deviceId);
spdlog::debug("QuantizedFluxModel forward_single_layer {}", idx);
hidden_states = hidden_states.contiguous();
@@ -117,6 +125,8 @@ public:
throw std::invalid_argument("skipRanks must be multiples of 16");
}
CUDADeviceContext ctx(deviceId);
spdlog::info("Set lora scale to {} (skip {} ranks)", scale, skipRanks);
net->traverse([&](Module *module) {
......
@@ -9,7 +9,12 @@
template<typename M>
class ModuleWrapper {
public:
void init(int deviceId) {
this->deviceId = deviceId;
}
void reset() {
CUDADeviceContext ctx(this->deviceId);
debugContext.reset();
net.reset();
Tensor::synchronizeDevice();
@@ -20,6 +25,7 @@ public:
void load(std::string path, bool partial = false) {
checkModel();
CUDADeviceContext ctx(this->deviceId);
spdlog::info("{} weights from {}", partial ? "Loading partial" : "Loading", path);
@@ -30,6 +36,19 @@ public:
spdlog::info("Done.");
}
void loadDict(std::map<std::string, torch::Tensor> dict, bool partial = false) {
checkModel();
CUDADeviceContext ctx(this->deviceId);
spdlog::info("{} weights from pytorch", partial ? "Loading partial" : "Loading");
std::shared_ptr<TensorsProviderTorch> provider = std::make_shared<TensorsProviderTorch>(std::move(dict));
net->loadParams(*provider, partial);
Tensor::synchronizeDevice();
spdlog::info("Done.");
}
void startDebug() {
debugContext = std::make_unique<DebugContext>();
}
@@ -38,6 +57,8 @@ public:
}
auto getDebugResults() {
CUDADeviceContext ctx(this->deviceId);
std::map<std::string, torch::Tensor> result;
if (debugContext) {
@@ -59,4 +80,6 @@ protected:
protected:
std::unique_ptr<M> net;
std::unique_ptr<DebugContext> debugContext;
int deviceId = -1;
};
\ No newline at end of file
@@ -22,6 +22,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("path"),
py::arg("partial") = false
)
.def("loadDict", &QuantizedFluxModel::loadDict,
py::arg("dict"),
py::arg("partial") = false
)
.def("forward", &QuantizedFluxModel::forward)
.def("forward_layer", &QuantizedFluxModel::forward_layer)
.def("forward_single_layer", &QuantizedFluxModel::forward_single_layer)
@@ -45,6 +49,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("path"),
py::arg("partial") = false
)
.def("loadDict", &QuantizedSanaModel::loadDict,
py::arg("dict"),
py::arg("partial") = false
)
.def("forward", &QuantizedSanaModel::forward)
.def("forward_layer", &QuantizedSanaModel::forward_layer)
.def("startDebug", &QuantizedSanaModel::startDebug)
......
@@ -9,7 +9,7 @@
class QuantizedSanaModel : public ModuleWrapper<SanaModel> {
public:
void init(pybind11::dict config, std::vector<int> pag_layers, bool use_fp4, bool bf16, int8_t deviceId) {
spdlog::info("Initializing QuantizedSanaModel");
spdlog::info("Initializing QuantizedSanaModel on device {}", deviceId);
SanaConfig cfg{
.num_layers = config["num_layers"].cast<int>(),
.num_attention_heads = config["num_attention_heads"].cast<int>(),
@@ -19,6 +19,9 @@ public:
.pag_layers = pag_layers,
.use_fp4 = use_fp4,
};
ModuleWrapper::init(deviceId);
CUDADeviceContext ctx(this->deviceId);
net = std::make_unique<SanaModel>(cfg, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
}
@@ -34,6 +37,7 @@ public:
bool cfg)
{
checkModel();
CUDADeviceContext ctx(deviceId);
spdlog::debug("QuantizedSanaModel forward");
@@ -72,6 +76,7 @@ public:
bool cfg)
{
checkModel();
CUDADeviceContext ctx(deviceId);
spdlog::debug("QuantizedSanaModel forward_layer {}", idx);
......
@@ -81,7 +81,8 @@ public:
BufferCUDA(size_t size) {
this->size = size;
this->device.type = Device::CUDA;
checkCUDA(cudaGetDevice(&this->device.idx));
// checkCUDA(cudaGetDevice(&this->device.idx));
this->device.idx = CUDADeviceContext::getDevice();
if (size == 0) {
this->ptr = nullptr;
}
@@ -418,6 +419,7 @@ public:
result.buffer = std::make_shared<BufferMalloc>(shape.size() * scalarSize.at(scalarType));
} else if (device.type == Device::CUDA) {
// TODO: cross device allocate
CUDADeviceContext ctx(device.idx);
result.buffer = std::make_shared<BufferCUDA>(shape.size() * scalarSize.at(scalarType));
} else {
assert(false);
@@ -429,6 +431,7 @@ public:
if (device.type == Device::CPU) {
memset(result.buffer->getPtr(), 0xCC, result.buffer->getSize());
} else if (device.type == Device::CUDA) {
CUDADeviceContext ctx(device.idx);
checkCUDA(cudaMemsetAsync(result.buffer->getPtr(), 0xCC, result.buffer->getSize(), getCurrentCUDAStream()));
}
}
......
@@ -107,16 +107,97 @@ struct CUDAEventWrapper {
}
};
/**
 * 1. hold one when entering from external code (pass `device = -1` to avoid changing the device)
 * 2. hold one when switching to another device
 * 3. hold one with `disableCache` when calling external code that may change the device
 * (see the usage sketch after this class)
 */
class CUDADeviceContext {
public:
CUDADeviceContext(int device = -1, bool disableCache = false) : disableCache(disableCache) {
if (cacheDisabled()) {
// no previous context => we may have entered from external code, reset cache
// previous context has disableCache on => external code may have run, reset
currentDeviceCache = -1;
}
ctxs.push(this);
lastDevice = getDevice();
if (device >= 0) {
setDevice(device);
}
if (disableCache) {
// we are about to call external code, reset cache
currentDeviceCache = -1;
}
}
CUDADeviceContext(const CUDADeviceContext &) = delete;
CUDADeviceContext(CUDADeviceContext &&) = delete;
~CUDADeviceContext() {
if (disableCache) {
// returned from external code, cache is not reliable, reset
currentDeviceCache = -1;
}
setDevice(lastDevice);
assert(ctxs.top() == this);
ctxs.pop();
if (cacheDisabled()) {
// ctxs.empty() => we are about to return to external code, reset cache
// otherwise => we are nested inside a context with disableCache on; external code may run next, reset
currentDeviceCache = -1;
}
}
const bool disableCache;
int lastDevice;
public:
static int getDevice() {
int idx = -1;
if (cacheDisabled() || currentDeviceCache < 0) {
checkCUDA(cudaGetDevice(&idx));
} else {
idx = currentDeviceCache;
}
currentDeviceCache = cacheDisabled() ? -1 : idx;
return idx;
}
private:
static void setDevice(int idx) {
// TODO: deal with stream when switching device
assert(idx >= 0);
if (!cacheDisabled() && currentDeviceCache == idx) {
return;
}
checkCUDA(cudaSetDevice(idx));
currentDeviceCache = cacheDisabled() ? -1 : idx;
}
private:
static inline thread_local std::stack<CUDADeviceContext *> ctxs;
static inline thread_local int currentDeviceCache = -1;
static bool cacheDisabled() {
return ctxs.empty() || ctxs.top()->disableCache;
}
};
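A minimal usage sketch of the three patterns described in the comment above; the functions and device numbers are illustrative, not part of this commit:

```cpp
// Pattern 1: entry point reached from external (e.g. torch/Python) code.
// device = -1: do not switch devices, just reset/anchor the device cache.
void entry_from_external() {
    CUDADeviceContext ctx;                    // device = -1, cache enabled
    // ... launch work on whatever device the caller had current ...
}

// Pattern 2: temporarily switch to a specific device.
void run_on(int deviceId) {
    CUDADeviceContext ctx(deviceId);          // calls cudaSetDevice() only if needed
    // ... work on deviceId ...
}                                             // destructor restores the previous device

// Pattern 3: calling external code that may itself change the device,
// so the cached device id cannot be trusted across the call.
void call_external() {
    CUDADeviceContext ctx(-1, /*disableCache=*/true);
    // ... call into torch / other libraries ...
}
```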
inline cudaDeviceProp *getCurrentDeviceProperties() {
static thread_local cudaDeviceProp prop;
static thread_local bool propAvailable = false;
if (!propAvailable) {
int device;
checkCUDA(cudaGetDevice(&device));
checkCUDA(cudaGetDeviceProperties(&prop, device));
propAvailable = true;
}
return &prop;
static thread_local std::map<int, cudaDeviceProp> props;
int deviceId = CUDADeviceContext::getDevice();
if (!props.contains(deviceId)) {
cudaDeviceProp prop;
checkCUDA(cudaGetDeviceProperties(&prop, deviceId));
props[deviceId] = prop;
}
return &props.at(deviceId);
}
template<typename T>
......
@@ -22,6 +22,7 @@ Tensor from_torch(at::Tensor input) {
}
static const std::map<at::ScalarType, Tensor::ScalarType> mapType = {
{ at::ScalarType::Char, Tensor::INT8 },
{ at::ScalarType::Byte, Tensor::INT8 },
{ at::ScalarType::Int, Tensor::INT32 },
{ at::ScalarType::Long, Tensor::INT64 },
@@ -36,7 +37,7 @@ Tensor from_torch(at::Tensor input) {
result.scalarType = mapType.at(input.scalar_type());
result.buffer = std::make_shared<BufferTorchTensor>(std::move(input));
// Tensor::lockBuffer(result.buffer, getCurrentCUDAStream());
Tensor::lockBuffer(result.buffer, getCurrentCUDAStream());
return result;
}
......
@@ -15,7 +15,7 @@ public:
}
virtual bool isAsyncBuffer() override {
// TODO: figure out how torch manages memory
return true;
return this->device.type == Device::CUDA;
}
private:
at::Tensor tensor;
@@ -30,4 +30,22 @@ public:
};
Tensor from_torch(at::Tensor input);
at::Tensor to_torch(Tensor input);
\ No newline at end of file
at::Tensor to_torch(Tensor input);
class TensorsProviderTorch : public TensorsProvider {
public:
TensorsProviderTorch(std::map<std::string, at::Tensor> dict) : storage(std::move(dict)) {}
virtual bool contains(const std::string &key) const override {
return storage.contains(key);
}
virtual Tensor getTensor(const std::string &key) override {
if (!storage.contains(key)) {
return Tensor{};
}
return from_torch(storage.at(key));
}
private:
std::map<std::string, at::Tensor> storage;
};
\ No newline at end of file