Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
fengzch-das
nunchaku
Commits
0c1c2d4a
Commit
0c1c2d4a
authored
Nov 09, 2024
by
sxtyzhangzk
Committed by
Zhekai Zhang
Nov 09, 2024
Browse files
[major] alternative method to load safetensors
parent
65348e71
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
94 additions
and
14 deletions
+94
-14
src/Serialization.cpp
src/Serialization.cpp
+88
-11
src/Serialization.h
src/Serialization.h
+6
-3
No files found.
src/Serialization.cpp
View file @
0c1c2d4a
...
...
@@ -3,28 +3,105 @@
#include <nlohmann/json.hpp>
#include <mio/mmap.hpp>
// #include <sys/mman.h>
using
json
=
nlohmann
::
json
;
using
spdlog
::
fmt_lib
::
format
;
class
SafeTensors
::
mmap_file
:
public
mio
::
mmap_source
{
class
SafeTensors
::
MMapImpl
{
public:
mmap_file
(
std
::
string_view
filename
)
:
mio
::
mmap_source
(
filename
,
0
,
mio
::
map_entire_file
)
{}
virtual
~
MMapImpl
()
{}
virtual
size_t
size
()
=
0
;
virtual
const
char
*
data
()
=
0
;
};
SafeTensors
::
SafeTensors
(
std
::
string_view
filename
)
{
std
::
error_code
ec
;
this
->
mapped
=
std
::
make_unique
<
mmap_file
>
(
filename
);
if
(
ec
)
{
throw
std
::
system_error
(
ec
);
class
SafeTensors
::
MMapImplMio
:
public
SafeTensors
::
MMapImpl
{
public:
MMapImplMio
(
const
std
::
string
&
filename
)
:
impl
(
filename
,
0
,
mio
::
map_entire_file
)
{}
virtual
size_t
size
()
override
{
return
impl
.
size
();
}
virtual
const
char
*
data
()
override
{
return
impl
.
data
();
}
private:
mio
::
mmap_source
impl
;
};
#ifdef __linux__
#include <unistd.h>
#include <fcntl.h>
#include <sys/mman.h>
class
SafeTensors
::
MMapImplPrivate
:
public
SafeTensors
::
MMapImpl
{
public:
MMapImplPrivate
(
const
std
::
string
&
filename
)
{
int
fd
=
open
(
filename
.
c_str
(),
O_RDONLY
);
if
(
fd
<
0
)
{
throw
std
::
system_error
(
errno
,
std
::
generic_category
(),
filename
);
}
struct
stat
statbuf
;
fstat
(
fd
,
&
statbuf
);
filesize
=
statbuf
.
st_size
;
ptr
=
mmap
(
0
,
filesize
,
PROT_READ
|
PROT_WRITE
,
MAP_PRIVATE
,
fd
,
0
);
if
(
ptr
==
MAP_FAILED
)
{
close
(
fd
);
throw
std
::
system_error
(
errno
,
std
::
generic_category
(),
filename
);
}
close
(
fd
);
}
~
MMapImplPrivate
()
{
munmap
(
ptr
,
filesize
);
}
virtual
size_t
size
()
override
{
return
filesize
;
}
virtual
const
char
*
data
()
override
{
return
(
const
char
*
)
ptr
;
}
private:
size_t
filesize
;
void
*
ptr
;
};
#else
class
SafeTensors
::
MMapImplPrivate
:
public
SafeTensors
::
MMapImpl
{
public:
MMapImplPrivate
(
const
std
::
string
&
filename
)
{
throw
std
::
runtime_error
(
"MAP_PRIVATE is not implemented on this system"
)
}
// char *ptr = (char *)malloc(1024);
// checkCUDA(cudaHostRegister(ptr, 1024, cudaHostRegisterDefault));
virtual
size_t
size
()
override
{
return
0
;
}
virtual
const
char
*
data
()
override
{
return
nullptr
;
}
};
#endif
SafeTensors
::
SafeTensors
(
const
std
::
string
&
filename
)
{
this
->
mapped
=
std
::
make_unique
<
MMapImplMio
>
(
filename
);
if
(
cudaHostRegister
(
const_cast
<
char
*>
(
this
->
mapped
->
data
()),
this
->
mapped
->
size
(),
cudaHostRegisterPortable
|
cudaHostRegisterReadOnly
)
!=
cudaSuccess
)
{
spdlog
::
warn
(
"Unable to pin memory: {}"
,
cudaGetErrorString
(
cudaGetLastError
()));
// mlock(const_cast<char *>(this->mapped->data()), this->mapped->size());
#ifdef __linux__
spdlog
::
info
(
"Try MAP_PRIVATE"
);
this
->
mapped
.
reset
();
this
->
mapped
=
std
::
make_unique
<
MMapImplPrivate
>
(
filename
);
checkCUDA
(
cudaHostRegister
(
const_cast
<
char
*>
(
this
->
mapped
->
data
()),
this
->
mapped
->
size
(),
cudaHostRegisterPortable
));
#endif
}
parseHeader
();
}
...
...
@@ -52,7 +129,7 @@ void SafeTensors::parseHeader() {
uint64_t
sizeHeader
=
*
reinterpret_cast
<
const
uint64_t
*>
(
this
->
mapped
->
data
());
check
(
this
->
mapped
->
size
()
-
8
>=
sizeHeader
);
json
header
=
json
::
parse
(
this
->
mapped
->
begin
()
+
8
,
this
->
mapped
->
begin
()
+
8
+
sizeHeader
);
json
header
=
json
::
parse
(
this
->
mapped
->
data
()
+
8
,
this
->
mapped
->
data
()
+
8
+
sizeHeader
);
const
uint64_t
offsetMax
=
this
->
mapped
->
size
()
-
sizeHeader
-
8
;
std
::
set
<
size_t
>
offsets
;
...
...
src/Serialization.h
View file @
0c1c2d4a
...
...
@@ -29,7 +29,7 @@ public:
class
SafeTensors
:
public
TensorsProvider
,
public
std
::
enable_shared_from_this
<
SafeTensors
>
{
public:
SafeTensors
(
std
::
string
_view
filename
);
SafeTensors
(
const
std
::
string
&
filename
);
~
SafeTensors
();
virtual
bool
contains
(
const
std
::
string
&
key
)
const
override
{
...
...
@@ -41,7 +41,10 @@ private:
void
parseHeader
();
private:
class
mmap_file
;
class
MMapImpl
;
class
MMapImplMio
;
class
MMapImplPrivate
;
struct
TensorInfo
{
TensorShape
shape
;
Tensor
::
ScalarType
type
;
...
...
@@ -50,5 +53,5 @@ private:
std
::
weak_ptr
<
BufferMMap
>
buffer
;
};
std
::
map
<
std
::
string
,
TensorInfo
>
tensors
;
std
::
unique_ptr
<
mmap_file
>
mapped
;
std
::
unique_ptr
<
MMapImpl
>
mapped
;
};
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment