fengzch-das / nunchaku · Commit 27232e7b

[major] add setDevice & load weights from pytorch

Authored Mar 10, 2025 by sxtyzhangzk; committed Apr 01, 2025 by Zhekai Zhang
Parent: 0b1891cd

Showing 8 changed files with 164 additions and 15 deletions:
  nunchaku/csrc/flux.h       +11   -1
  nunchaku/csrc/module.h     +23   -0
  nunchaku/csrc/pybind.cpp    +8   -0
  nunchaku/csrc/sana.h        +6   -1
  src/Tensor.h                +4   -1
  src/common.h               +90   -9
  src/interop/torch.cpp       +2   -1
  src/interop/torch.h        +20   -2

nunchaku/csrc/flux.h

@@ -10,10 +10,13 @@
 class QuantizedFluxModel : public ModuleWrapper<FluxModel> { // : public torch::CustomClassHolder {
 public:
     void init(bool use_fp4, bool offload, bool bf16, int8_t deviceId) {
-        spdlog::info("Initializing QuantizedFluxModel");
+        spdlog::info("Initializing QuantizedFluxModel on device {}", deviceId);
         if (offload) {
             spdlog::info("Layer offloading enabled");
         }
+        ModuleWrapper::init(deviceId);
+        CUDADeviceContext ctx(this->deviceId);
+
         net = std::make_unique<FluxModel>(use_fp4, offload, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
     }
@@ -27,6 +30,7 @@ public:
         bool skip_first_layer = false)
     {
         checkModel();
+        CUDADeviceContext ctx(deviceId);
         spdlog::debug("QuantizedFluxModel forward");
@@ -61,6 +65,8 @@ public:
         torch::Tensor rotary_emb_img,
         torch::Tensor rotary_emb_context)
     {
+        CUDADeviceContext ctx(deviceId);
+
         spdlog::debug("QuantizedFluxModel forward_layer {}", idx);
         hidden_states = hidden_states.contiguous();
@@ -91,6 +97,8 @@ public:
         torch::Tensor temb,
         torch::Tensor rotary_emb_single)
     {
+        CUDADeviceContext ctx(deviceId);
+
         spdlog::debug("QuantizedFluxModel forward_single_layer {}", idx);
         hidden_states = hidden_states.contiguous();
@@ -117,6 +125,8 @@ public:
             throw std::invalid_argument("skipRanks must be multiples of 16");
         }
+        CUDADeviceContext ctx(deviceId);
+
         spdlog::info("Set lora scale to {} (skip {} ranks)", scale, skipRanks);
         net->traverse([&](Module *module) {

nunchaku/csrc/module.h

@@ -9,7 +9,12 @@
 template<typename M>
 class ModuleWrapper {
 public:
+    void init(int deviceId) {
+        this->deviceId = deviceId;
+    }
+
     void reset() {
+        CUDADeviceContext ctx(this->deviceId);
         debugContext.reset();
         net.reset();
         Tensor::synchronizeDevice();
@@ -20,6 +25,7 @@ public:
     void load(std::string path, bool partial = false) {
         checkModel();
+        CUDADeviceContext ctx(this->deviceId);
         spdlog::info("{} weights from {}", partial ? "Loading partial" : "Loading", path);
@@ -30,6 +36,19 @@ public:
         spdlog::info("Done.");
     }

+    void loadDict(std::map<std::string, torch::Tensor> dict, bool partial = false) {
+        checkModel();
+        CUDADeviceContext ctx(this->deviceId);
+
+        spdlog::info("{} weights from pytorch", partial ? "Loading partial" : "Loading");
+
+        std::shared_ptr<TensorsProviderTorch> provider = std::make_shared<TensorsProviderTorch>(std::move(dict));
+        net->loadParams(*provider, partial);
+        Tensor::synchronizeDevice();
+
+        spdlog::info("Done.");
+    }
+
     void startDebug() {
         debugContext = std::make_unique<DebugContext>();
     }
@@ -38,6 +57,8 @@ public:
     }

     auto getDebugResults() {
+        CUDADeviceContext ctx(this->deviceId);
+
         std::map<std::string, torch::Tensor> result;
         if (debugContext) {
@@ -59,4 +80,6 @@ protected:
 protected:
     std::unique_ptr<M> net;
     std::unique_ptr<DebugContext> debugContext;
+
+    int deviceId = -1;
 };
\ No newline at end of file
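
The new loadDict entry point accepts weights that already live in PyTorch tensors instead of a checkpoint path. A minimal caller-side sketch in C++; the surrounding function and the include path are illustrative, not part of this commit:

    // Hypothetical caller-side sketch: feed an in-memory PyTorch state dict to the model.
    #include <map>
    #include <string>
    #include <torch/torch.h>
    #include "nunchaku/csrc/flux.h"   // assumed include path

    void loadFromStateDict(QuantizedFluxModel &model,
                           std::map<std::string, torch::Tensor> stateDict) {
        // loadDict wraps the tensors in a TensorsProviderTorch and forwards them to
        // net->loadParams(); partial = true would skip parameters missing from the map.
        model.loadDict(std::move(stateDict), /*partial=*/false);
    }

In practice this path is reached from Python through the new pybind11 "loadDict" bindings shown in pybind.cpp below.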

nunchaku/csrc/pybind.cpp

@@ -22,6 +22,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
             py::arg("path"),
             py::arg("partial") = false
         )
+        .def("loadDict", &QuantizedFluxModel::loadDict,
+            py::arg("dict"),
+            py::arg("partial") = false
+        )
         .def("forward", &QuantizedFluxModel::forward)
         .def("forward_layer", &QuantizedFluxModel::forward_layer)
         .def("forward_single_layer", &QuantizedFluxModel::forward_single_layer)
@@ -45,6 +49,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
             py::arg("path"),
             py::arg("partial") = false
         )
+        .def("loadDict", &QuantizedSanaModel::loadDict,
+            py::arg("dict"),
+            py::arg("partial") = false
+        )
         .def("forward", &QuantizedSanaModel::forward)
         .def("forward_layer", &QuantizedSanaModel::forward_layer)
         .def("startDebug", &QuantizedSanaModel::startDebug)

nunchaku/csrc/sana.h

@@ -9,7 +9,7 @@
 class QuantizedSanaModel : public ModuleWrapper<SanaModel> {
 public:
     void init(pybind11::dict config, std::vector<int> pag_layers, bool use_fp4, bool bf16, int8_t deviceId) {
-        spdlog::info("Initializing QuantizedSanaModel");
+        spdlog::info("Initializing QuantizedSanaModel on device {}", deviceId);
         SanaConfig cfg{
             .num_layers = config["num_layers"].cast<int>(),
             .num_attention_heads = config["num_attention_heads"].cast<int>(),
@@ -19,6 +19,9 @@ public:
             .pag_layers = pag_layers,
             .use_fp4 = use_fp4,
         };
+        ModuleWrapper::init(deviceId);
+        CUDADeviceContext ctx(this->deviceId);
+
         net = std::make_unique<SanaModel>(cfg, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
     }
@@ -34,6 +37,7 @@ public:
         bool cfg)
     {
         checkModel();
+        CUDADeviceContext ctx(deviceId);
         spdlog::debug("QuantizedSanaModel forward");
@@ -72,6 +76,7 @@ public:
         bool cfg)
     {
         checkModel();
+        CUDADeviceContext ctx(deviceId);
         spdlog::debug("QuantizedSanaModel forward_layer {}", idx);

src/Tensor.h

@@ -81,7 +81,8 @@ public:
     BufferCUDA(size_t size) {
         this->size = size;
         this->device.type = Device::CUDA;
-        checkCUDA(cudaGetDevice(&this->device.idx));
+        // checkCUDA(cudaGetDevice(&this->device.idx));
+        this->device.idx = CUDADeviceContext::getDevice();
         if (size == 0) {
             this->ptr = nullptr;
         }
@@ -418,6 +419,7 @@ public:
             result.buffer = std::make_shared<BufferMalloc>(shape.size() * scalarSize.at(scalarType));
         } else if (device.type == Device::CUDA) {
             // TODO: cross device allocate
+            CUDADeviceContext ctx(device.idx);
             result.buffer = std::make_shared<BufferCUDA>(shape.size() * scalarSize.at(scalarType));
         } else {
             assert(false);
@@ -429,6 +431,7 @@ public:
             if (device.type == Device::CPU) {
                 memset(result.buffer->getPtr(), 0xCC, result.buffer->getSize());
             } else if (device.type == Device::CUDA) {
+                CUDADeviceContext ctx(device.idx);
                 checkCUDA(cudaMemsetAsync(result.buffer->getPtr(), 0xCC, result.buffer->getSize(), getCurrentCUDAStream()));
             }
         }

src/common.h

@@ -107,16 +107,97 @@ struct CUDAEventWrapper {
     }
 };

+/**
+ * 1. hold one when entered from external code (set `device` to -1 to avoid a device change)
+ * 2. hold one when switching device
+ * 3. hold one with `disableCache` when calling external code that may change the device
+ */
+class CUDADeviceContext {
+public:
+    CUDADeviceContext(int device = -1, bool disableCache = false) : disableCache(disableCache) {
+        if (cacheDisabled()) {
+            // no previous context => we might have entered from external code, reset cache
+            // previous context has disableCache on => external code may have been executed, reset
+            currentDeviceCache = -1;
+        }
+        ctxs.push(this);
+        lastDevice = getDevice();
+        if (device >= 0) {
+            setDevice(device);
+        }
+        if (disableCache) {
+            // we are about to call external code, reset cache
+            currentDeviceCache = -1;
+        }
+    }
+    CUDADeviceContext(const CUDADeviceContext &) = delete;
+    CUDADeviceContext(CUDADeviceContext &&) = delete;
+    ~CUDADeviceContext() {
+        if (disableCache) {
+            // returned from external code, cache is not reliable, reset
+            currentDeviceCache = -1;
+        }
+        setDevice(lastDevice);
+        assert(ctxs.top() == this);
+        ctxs.pop();
+        if (cacheDisabled()) {
+            // ctxs.empty() => we are about to return to external code, reset cache
+            // otherwise => we are nested inside a context with disableCache on; external code may still run, reset
+            currentDeviceCache = -1;
+        }
+    }
+
+    const bool disableCache;
+    int lastDevice;
+
+public:
+    static int getDevice() {
+        int idx = -1;
+        if (cacheDisabled() || currentDeviceCache < 0) {
+            checkCUDA(cudaGetDevice(&idx));
+        } else {
+            idx = currentDeviceCache;
+        }
+        currentDeviceCache = cacheDisabled() ? -1 : idx;
+        return idx;
+    }
+
+private:
+    static void setDevice(int idx) {
+        // TODO: deal with stream when switching device
+        assert(idx >= 0);
+        if (!cacheDisabled() && currentDeviceCache == idx) {
+            return;
+        }
+        checkCUDA(cudaSetDevice(idx));
+        currentDeviceCache = cacheDisabled() ? -1 : idx;
+    }
+
+private:
+    static inline thread_local std::stack<CUDADeviceContext *> ctxs;
+    static inline thread_local int currentDeviceCache = -1;
+
+    static bool cacheDisabled() {
+        return ctxs.empty() || ctxs.top()->disableCache;
+    }
+};
+
 inline cudaDeviceProp *getCurrentDeviceProperties() {
-    static thread_local cudaDeviceProp prop;
-    static thread_local bool propAvailable = false;
-    if (!propAvailable) {
-        int device;
-        checkCUDA(cudaGetDevice(&device));
-        checkCUDA(cudaGetDeviceProperties(&prop, device));
-        propAvailable = true;
+    static thread_local std::map<int, cudaDeviceProp> props;
+    int deviceId = CUDADeviceContext::getDevice();
+    if (!props.contains(deviceId)) {
+        cudaDeviceProp prop;
+        checkCUDA(cudaGetDeviceProperties(&prop, deviceId));
+        props[deviceId] = prop;
     }
-    return &prop;
+    return &props.at(deviceId);
 }

 template<typename T>
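
The class comment above lists three intended usage patterns for the new RAII guard. A short sketch of what each pattern looks like at a call site; the function names here are hypothetical:

    // Illustrative only; the call sites below are not from this commit.

    void enteredFromExternalCode() {
        CUDADeviceContext ctx;                  // pattern 1: device = -1, no device switch;
                                                // (re)establishes the thread-local device cache
        // ... internal work on whatever device is current ...
    }

    void runOnDevice(int deviceId) {
        CUDADeviceContext ctx(deviceId);        // pattern 2: switch to deviceId for this scope;
                                                // the previous device is restored in ~CUDADeviceContext
        // ... allocate BufferCUDA / launch kernels on deviceId ...
    }

    void callExternalLibrary() {
        CUDADeviceContext ctx(-1, /*disableCache=*/true);  // pattern 3: the callee may call
                                                           // cudaSetDevice, so the cached device
                                                           // index is treated as unreliable
        // ... call into libtorch or other external code ...
    }

This mirrors how the wrappers in flux.h, sana.h, and module.h now hold a CUDADeviceContext around every init, forward, and load call.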

src/interop/torch.cpp

@@ -22,6 +22,7 @@ Tensor from_torch(at::Tensor input) {
     }

     static const std::map<at::ScalarType, Tensor::ScalarType> mapType = {
+        {at::ScalarType::Char, Tensor::INT8},
         {at::ScalarType::Byte, Tensor::INT8},
         {at::ScalarType::Int, Tensor::INT32},
         {at::ScalarType::Long, Tensor::INT64},
@@ -36,7 +37,7 @@ Tensor from_torch(at::Tensor input) {
     result.scalarType = mapType.at(input.scalar_type());
     result.buffer = std::make_shared<BufferTorchTensor>(std::move(input));
-    // Tensor::lockBuffer(result.buffer, getCurrentCUDAStream());
+    Tensor::lockBuffer(result.buffer, getCurrentCUDAStream());
     return result;
 }

src/interop/torch.h

@@ -15,7 +15,7 @@ public:
     }
     virtual bool isAsyncBuffer() override {
         // TODO: figure out how torch manages memory
-        return true;
+        return this->device.type == Device::CUDA;
     }
 private:
     at::Tensor tensor;
@@ -30,4 +30,22 @@ public:
 };

 Tensor from_torch(at::Tensor input);
-at::Tensor to_torch(Tensor input);
\ No newline at end of file
+at::Tensor to_torch(Tensor input);
+
+class TensorsProviderTorch : public TensorsProvider {
+public:
+    TensorsProviderTorch(std::map<std::string, at::Tensor> dict) : storage(std::move(dict)) {}
+
+    virtual bool contains(const std::string &key) const override {
+        return storage.contains(key);
+    }
+
+    virtual Tensor getTensor(const std::string &key) override {
+        if (!storage.contains(key)) {
+            return Tensor{};
+        }
+        return from_torch(storage.at(key));
+    }
+
+private:
+    std::map<std::string, at::Tensor> storage;
+};
\ No newline at end of file
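
As a usage note, a small hypothetical sketch of how a weight loader consumes this adapter through the TensorsProvider interface (the helper function below is illustrative, not part of the commit):

    // Illustrative only: the loader side probes keys and converts them lazily.
    void copyWeightIfPresent(TensorsProvider &provider, const std::string &key) {
        if (!provider.contains(key)) {
            return;                              // partial loads simply skip missing keys
        }
        Tensor w = provider.getTensor(key);      // wraps the at::Tensor storage via from_torch,
                                                 // without copying the weight data host-side
        // ... assign `w` to the corresponding module parameter ...
    }

This is the provider that ModuleWrapper::loadDict builds from the Python-supplied dict before calling net->loadParams().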