Commit 0c1c2d4a authored by sxtyzhangzk's avatar sxtyzhangzk Committed by Zhekai Zhang
Browse files

[major] alternative method to load safetensors

parent 65348e71
......@@ -3,28 +3,105 @@
#include <nlohmann/json.hpp>
#include <mio/mmap.hpp>
// #include <sys/mman.h>
using json = nlohmann::json;
using spdlog::fmt_lib::format;
class SafeTensors::mmap_file : public mio::mmap_source {
class SafeTensors::MMapImpl {
public:
mmap_file(std::string_view filename) : mio::mmap_source(filename, 0, mio::map_entire_file) {}
virtual ~MMapImpl() {}
virtual size_t size() = 0;
virtual const char *data() = 0;
};
SafeTensors::SafeTensors(std::string_view filename) {
std::error_code ec;
this->mapped = std::make_unique<mmap_file>(filename);
if (ec) {
throw std::system_error(ec);
class SafeTensors::MMapImplMio : public SafeTensors::MMapImpl {
public:
MMapImplMio(const std::string &filename) : impl(filename, 0, mio::map_entire_file) {}
virtual size_t size() override {
return impl.size();
}
virtual const char *data() override {
return impl.data();
}
private:
mio::mmap_source impl;
};
#ifdef __linux__
#include <unistd.h>
#include <fcntl.h>
#include <sys/mman.h>
class SafeTensors::MMapImplPrivate : public SafeTensors::MMapImpl {
public:
MMapImplPrivate(const std::string &filename) {
int fd = open(filename.c_str(), O_RDONLY);
if (fd < 0) {
throw std::system_error(errno, std::generic_category(), filename);
}
struct stat statbuf;
fstat(fd, &statbuf);
filesize = statbuf.st_size;
ptr = mmap(0, filesize, PROT_READ | PROT_WRITE, MAP_PRIVATE, fd, 0);
if (ptr == MAP_FAILED) {
close(fd);
throw std::system_error(errno, std::generic_category(), filename);
}
close(fd);
}
~MMapImplPrivate() {
munmap(ptr, filesize);
}
virtual size_t size() override {
return filesize;
}
virtual const char *data() override {
return (const char *)ptr;
}
private:
size_t filesize;
void *ptr;
};
#else
class SafeTensors::MMapImplPrivate : public SafeTensors::MMapImpl {
public:
MMapImplPrivate(const std::string &filename) {
throw std::runtime_error("MAP_PRIVATE is not implemented on this system")
}
// char *ptr = (char *)malloc(1024);
// checkCUDA(cudaHostRegister(ptr, 1024, cudaHostRegisterDefault));
virtual size_t size() override {
return 0;
}
virtual const char *data() override {
return nullptr;
}
};
#endif
SafeTensors::SafeTensors(const std::string &filename) {
this->mapped = std::make_unique<MMapImplMio>(filename);
if (cudaHostRegister(const_cast<char *>(this->mapped->data()), this->mapped->size(), cudaHostRegisterPortable | cudaHostRegisterReadOnly) != cudaSuccess) {
spdlog::warn("Unable to pin memory: {}", cudaGetErrorString(cudaGetLastError()));
// mlock(const_cast<char *>(this->mapped->data()), this->mapped->size());
#ifdef __linux__
spdlog::info("Try MAP_PRIVATE");
this->mapped.reset();
this->mapped = std::make_unique<MMapImplPrivate>(filename);
checkCUDA(cudaHostRegister(const_cast<char *>(this->mapped->data()), this->mapped->size(), cudaHostRegisterPortable));
#endif
}
parseHeader();
}
......@@ -52,7 +129,7 @@ void SafeTensors::parseHeader() {
uint64_t sizeHeader = *reinterpret_cast<const uint64_t *>(this->mapped->data());
check(this->mapped->size() - 8 >= sizeHeader);
json header = json::parse(this->mapped->begin() + 8, this->mapped->begin() + 8 + sizeHeader);
json header = json::parse(this->mapped->data() + 8, this->mapped->data() + 8 + sizeHeader);
const uint64_t offsetMax = this->mapped->size() - sizeHeader - 8;
std::set<size_t> offsets;
......
......@@ -29,7 +29,7 @@ public:
class SafeTensors : public TensorsProvider, public std::enable_shared_from_this<SafeTensors> {
public:
SafeTensors(std::string_view filename);
SafeTensors(const std::string &filename);
~SafeTensors();
virtual bool contains(const std::string &key) const override {
......@@ -41,7 +41,10 @@ private:
void parseHeader();
private:
class mmap_file;
class MMapImpl;
class MMapImplMio;
class MMapImplPrivate;
struct TensorInfo {
TensorShape shape;
Tensor::ScalarType type;
......@@ -50,5 +53,5 @@ private:
std::weak_ptr<BufferMMap> buffer;
};
std::map<std::string, TensorInfo> tensors;
std::unique_ptr<mmap_file> mapped;
std::unique_ptr<MMapImpl> mapped;
};
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment