yongshk / candle

Commit 25d2752f, authored May 29, 2025 by yongshk
Initial commit
Showing 18 changed files with 3058 additions and 0 deletions (+3058, -0)
candle-examples/examples/reinforcement-learning/atari_wrappers.py  (+308, -0)
candle-examples/examples/reinforcement-learning/ddpg.rs  (+556, -0)
candle-examples/examples/reinforcement-learning/dqn.rs  (+118, -0)
candle-examples/examples/reinforcement-learning/gym_env.rs  (+114, -0)
candle-examples/examples/reinforcement-learning/main.rs  (+40, -0)
candle-examples/examples/reinforcement-learning/policy_gradient.rs  (+146, -0)
candle-examples/examples/reinforcement-learning/vec_gym_env.rs  (+91, -0)
candle-examples/examples/replit-code/README.md  (+40, -0)
candle-examples/examples/replit-code/main.rs  (+264, -0)
candle-examples/examples/repvgg/README.md  (+22, -0)
candle-examples/examples/repvgg/main.rs  (+111, -0)
candle-examples/examples/resnet/README.md  (+19, -0)
candle-examples/examples/resnet/export_models.py  (+12, -0)
candle-examples/examples/resnet/main.rs  (+90, -0)
candle-examples/examples/rwkv/README.md  (+17, -0)
candle-examples/examples/rwkv/main.rs  (+330, -0)
candle-examples/examples/segformer/README.md  (+28, -0)
candle-examples/examples/segformer/assets/labels.json  (+752, -0)

candle-examples/examples/reinforcement-learning/atari_wrappers.py (new file, mode 100644)

```python
import gymnasium as gym
import numpy as np
from collections import deque
from PIL import Image
from multiprocessing import Process, Pipe

# atari_wrappers.py
class NoopResetEnv(gym.Wrapper):
    def __init__(self, env, noop_max=30):
        """Sample initial states by taking random number of no-ops on reset.
        No-op is assumed to be action 0.
        """
        gym.Wrapper.__init__(self, env)
        self.noop_max = noop_max
        self.override_num_noops = None
        assert env.unwrapped.get_action_meanings()[0] == 'NOOP'

    def reset(self):
        """ Do no-op action for a number of steps in [1, noop_max]."""
        self.env.reset()
        if self.override_num_noops is not None:
            noops = self.override_num_noops
        else:
            noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)  #pylint: disable=E1101
        assert noops > 0
        obs = None
        for _ in range(noops):
            obs, _, done, _ = self.env.step(0)
            if done:
                obs = self.env.reset()
        return obs

class FireResetEnv(gym.Wrapper):
    def __init__(self, env):
        """Take action on reset for environments that are fixed until firing."""
        gym.Wrapper.__init__(self, env)
        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
        assert len(env.unwrapped.get_action_meanings()) >= 3

    def reset(self):
        self.env.reset()
        obs, _, done, _ = self.env.step(1)
        if done:
            self.env.reset()
        obs, _, done, _ = self.env.step(2)
        if done:
            self.env.reset()
        return obs

class ImageSaver(gym.Wrapper):
    def __init__(self, env, img_path, rank):
        gym.Wrapper.__init__(self, env)
        self._cnt = 0
        self._img_path = img_path
        self._rank = rank

    def step(self, action):
        step_result = self.env.step(action)
        obs, _, _, _ = step_result
        img = Image.fromarray(obs, 'RGB')
        img.save('%s/out%d-%05d.png' % (self._img_path, self._rank, self._cnt))
        self._cnt += 1
        return step_result

class EpisodicLifeEnv(gym.Wrapper):
    def __init__(self, env):
        """Make end-of-life == end-of-episode, but only reset on true game over.
        Done by DeepMind for the DQN and co. since it helps value estimation.
        """
        gym.Wrapper.__init__(self, env)
        self.lives = 0
        self.was_real_done = True

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self.was_real_done = done
        # check current lives, make loss of life terminal,
        # then update lives to handle bonus lives
        lives = self.env.unwrapped.ale.lives()
        if lives < self.lives and lives > 0:
            # for Qbert sometimes we stay in lives == 0 condition for a few frames
            # so its important to keep lives > 0, so that we only reset once
            # the environment advertises done.
            done = True
        self.lives = lives
        return obs, reward, done, info

    def reset(self):
        """Reset only when lives are exhausted.
        This way all states are still reachable even though lives are episodic,
        and the learner need not know about any of this behind-the-scenes.
        """
        if self.was_real_done:
            obs = self.env.reset()
        else:
            # no-op step to advance from terminal/lost life state
            obs, _, _, _ = self.env.step(0)
        self.lives = self.env.unwrapped.ale.lives()
        return obs

class MaxAndSkipEnv(gym.Wrapper):
    def __init__(self, env, skip=4):
        """Return only every `skip`-th frame"""
        gym.Wrapper.__init__(self, env)
        # most recent raw observations (for max pooling across time steps)
        self._obs_buffer = deque(maxlen=2)
        self._skip = skip

    def step(self, action):
        """Repeat action, sum reward, and max over last observations."""
        total_reward = 0.0
        done = None
        for _ in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            self._obs_buffer.append(obs)
            total_reward += reward
            if done:
                break
        max_frame = np.max(np.stack(self._obs_buffer), axis=0)
        return max_frame, total_reward, done, info

    def reset(self):
        """Clear past frame buffer and init. to first obs. from inner env."""
        self._obs_buffer.clear()
        obs = self.env.reset()
        self._obs_buffer.append(obs)
        return obs

class ClipRewardEnv(gym.RewardWrapper):
    def reward(self, reward):
        """Bin reward to {+1, 0, -1} by its sign."""
        return np.sign(reward)

class WarpFrame(gym.ObservationWrapper):
    def __init__(self, env):
        """Warp frames to 84x84 as done in the Nature paper and later work."""
        gym.ObservationWrapper.__init__(self, env)
        self.res = 84
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(self.res, self.res, 1), dtype='uint8')

    def observation(self, obs):
        frame = np.dot(obs.astype('float32'), np.array([0.299, 0.587, 0.114], 'float32'))
        frame = np.array(
            Image.fromarray(frame).resize((self.res, self.res), resample=Image.BILINEAR),
            dtype=np.uint8)
        return frame.reshape((self.res, self.res, 1))

class FrameStack(gym.Wrapper):
    def __init__(self, env, k):
        """Buffer observations and stack across channels (last axis)."""
        gym.Wrapper.__init__(self, env)
        self.k = k
        self.frames = deque([], maxlen=k)
        shp = env.observation_space.shape
        assert shp[2] == 1  # can only stack 1-channel frames
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(shp[0], shp[1], k), dtype='uint8')

    def reset(self):
        """Clear buffer and re-fill by duplicating the first observation."""
        ob = self.env.reset()
        for _ in range(self.k):
            self.frames.append(ob)
        return self.observation()

    def step(self, action):
        ob, reward, done, info = self.env.step(action)
        self.frames.append(ob)
        return self.observation(), reward, done, info

    def observation(self):
        assert len(self.frames) == self.k
        return np.concatenate(self.frames, axis=2)

def wrap_deepmind(env, episode_life=True, clip_rewards=True):
    """Configure environment for DeepMind-style Atari.
    Note: this does not include frame stacking!"""
    assert 'NoFrameskip' in env.spec.id  # required for DeepMind-style skip
    if episode_life:
        env = EpisodicLifeEnv(env)
    env = NoopResetEnv(env, noop_max=30)
    env = MaxAndSkipEnv(env, skip=4)
    if 'FIRE' in env.unwrapped.get_action_meanings():
        env = FireResetEnv(env)
    env = WarpFrame(env)
    if clip_rewards:
        env = ClipRewardEnv(env)
    return env

# envs.py
def make_env(env_id, img_dir, seed, rank):
    def _thunk():
        env = gym.make(env_id)
        env.reset(seed=(seed + rank))
        if img_dir is not None:
            env = ImageSaver(env, img_dir, rank)
        env = wrap_deepmind(env)
        env = WrapPyTorch(env)
        return env
    return _thunk

class WrapPyTorch(gym.ObservationWrapper):
    def __init__(self, env=None):
        super(WrapPyTorch, self).__init__(env)
        self.observation_space = gym.spaces.Box(0.0, 1.0, [1, 84, 84], dtype='float32')

    def observation(self, observation):
        return observation.transpose(2, 0, 1)

# vecenv.py
class VecEnv(object):
    """
    Vectorized environment base class
    """
    def step(self, vac):
        """
        Apply sequence of actions to sequence of environments
        actions -> (observations, rewards, news)
        where 'news' is a boolean vector indicating whether each element is new.
        """
        raise NotImplementedError

    def reset(self):
        """
        Reset all environments
        """
        raise NotImplementedError

    def close(self):
        pass

# subproc_vec_env.py
def worker(remote, env_fn_wrapper):
    env = env_fn_wrapper.x()
    while True:
        cmd, data = remote.recv()
        if cmd == 'step':
            ob, reward, done, info = env.step(data)
            if done:
                ob = env.reset()
            remote.send((ob, reward, done, info))
        elif cmd == 'reset':
            ob = env.reset()
            remote.send(ob)
        elif cmd == 'close':
            remote.close()
            break
        elif cmd == 'get_spaces':
            remote.send((env.action_space, env.observation_space))
        else:
            raise NotImplementedError

class CloudpickleWrapper(object):
    """
    Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
    """
    def __init__(self, x):
        self.x = x

    def __getstate__(self):
        import cloudpickle
        return cloudpickle.dumps(self.x)

    def __setstate__(self, ob):
        import pickle
        self.x = pickle.loads(ob)

class SubprocVecEnv(VecEnv):
    def __init__(self, env_fns):
        """
        envs: list of gym environments to run in subprocesses
        """
        nenvs = len(env_fns)
        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
        self.ps = [
            Process(target=worker, args=(work_remote, CloudpickleWrapper(env_fn)))
            for (work_remote, env_fn) in zip(self.work_remotes, env_fns)
        ]
        for p in self.ps:
            p.start()
        self.remotes[0].send(('get_spaces', None))
        self.action_space, self.observation_space = self.remotes[0].recv()

    def step(self, actions):
        for remote, action in zip(self.remotes, actions):
            remote.send(('step', action))
        results = [remote.recv() for remote in self.remotes]
        obs, rews, dones, infos = zip(*results)
        return np.stack(obs), np.stack(rews), np.stack(dones), infos

    def reset(self):
        for remote in self.remotes:
            remote.send(('reset', None))
        return np.stack([remote.recv() for remote in self.remotes])

    def close(self):
        for remote in self.remotes:
            remote.send(('close', None))
        for p in self.ps:
            p.join()

    @property
    def num_envs(self):
        return len(self.remotes)

# Create the environment.
def make(env_name, img_dir, num_processes):
    envs = SubprocVecEnv([
        make_env(env_name, img_dir, 1337, i)
        for i in range(num_processes)
    ])
    return envs
```

candle-examples/examples/reinforcement-learning/ddpg.rs (new file, mode 100644)

```rust
use std::collections::VecDeque;
use std::fmt::Display;

use candle::{DType, Device, Error, Module, Result, Tensor, Var};
use candle_nn::{
    func, linear, sequential::seq, Activation, AdamW, Optimizer, ParamsAdamW, Sequential,
    VarBuilder, VarMap,
};
use rand::{distributions::Uniform, thread_rng, Rng};

use super::gym_env::GymEnv;

pub struct OuNoise {
    mu: f64,
    theta: f64,
    sigma: f64,
    state: Tensor,
}

impl OuNoise {
    pub fn new(mu: f64, theta: f64, sigma: f64, size_action: usize) -> Result<Self> {
        Ok(Self {
            mu,
            theta,
            sigma,
            state: Tensor::ones(size_action, DType::F32, &Device::Cpu)?,
        })
    }

    pub fn sample(&mut self) -> Result<Tensor> {
        let rand = Tensor::randn_like(&self.state, 0.0, 1.0)?;
        let dx = ((self.theta * (self.mu - &self.state)?)? + (self.sigma * rand)?)?;
        self.state = (&self.state + dx)?;
        Ok(self.state.clone())
    }
}

#[derive(Clone)]
struct Transition {
    state: Tensor,
    action: Tensor,
    reward: Tensor,
    next_state: Tensor,
    terminated: bool,
    truncated: bool,
}

impl Transition {
    fn new(
        state: &Tensor,
        action: &Tensor,
        reward: &Tensor,
        next_state: &Tensor,
        terminated: bool,
        truncated: bool,
    ) -> Self {
        Self {
            state: state.clone(),
            action: action.clone(),
            reward: reward.clone(),
            next_state: next_state.clone(),
            terminated,
            truncated,
        }
    }
}

pub struct ReplayBuffer {
    buffer: VecDeque<Transition>,
    capacity: usize,
    size: usize,
}

impl ReplayBuffer {
    pub fn new(capacity: usize) -> Self {
        Self {
            buffer: VecDeque::with_capacity(capacity),
            capacity,
            size: 0,
        }
    }

    pub fn push(
        &mut self,
        state: &Tensor,
        action: &Tensor,
        reward: &Tensor,
        next_state: &Tensor,
        terminated: bool,
        truncated: bool,
    ) {
        if self.size == self.capacity {
            self.buffer.pop_front();
        } else {
            self.size += 1;
        }
        self.buffer.push_back(Transition::new(
            state, action, reward, next_state, terminated, truncated,
        ));
    }

    #[allow(clippy::type_complexity)]
    pub fn random_batch(
        &self,
        batch_size: usize,
    ) -> Result<Option<(Tensor, Tensor, Tensor, Tensor, Vec<bool>, Vec<bool>)>> {
        if self.size < batch_size {
            Ok(None)
        } else {
            let transitions: Vec<&Transition> = thread_rng()
                .sample_iter(Uniform::from(0..self.size))
                .take(batch_size)
                .map(|i| self.buffer.get(i).unwrap())
                .collect();

            let states: Vec<Tensor> = transitions
                .iter()
                .map(|t| t.state.unsqueeze(0))
                .collect::<Result<_>>()?;
            let actions: Vec<Tensor> = transitions
                .iter()
                .map(|t| t.action.unsqueeze(0))
                .collect::<Result<_>>()?;
            let rewards: Vec<Tensor> = transitions
                .iter()
                .map(|t| t.reward.unsqueeze(0))
                .collect::<Result<_>>()?;
            let next_states: Vec<Tensor> = transitions
                .iter()
                .map(|t| t.next_state.unsqueeze(0))
                .collect::<Result<_>>()?;
            let terminateds: Vec<bool> = transitions.iter().map(|t| t.terminated).collect();
            let truncateds: Vec<bool> = transitions.iter().map(|t| t.truncated).collect();

            Ok(Some((
                Tensor::cat(&states, 0)?,
                Tensor::cat(&actions, 0)?,
                Tensor::cat(&rewards, 0)?,
                Tensor::cat(&next_states, 0)?,
                terminateds,
                truncateds,
            )))
        }
    }
}

fn track(
    varmap: &mut VarMap,
    vb: &VarBuilder,
    target_prefix: &str,
    network_prefix: &str,
    dims: &[(usize, usize)],
    tau: f64,
) -> Result<()> {
    for (i, &(in_dim, out_dim)) in dims.iter().enumerate() {
        let target_w = vb.get((out_dim, in_dim), &format!("{target_prefix}-fc{i}.weight"))?;
        let network_w = vb.get((out_dim, in_dim), &format!("{network_prefix}-fc{i}.weight"))?;
        varmap.set_one(
            format!("{target_prefix}-fc{i}.weight"),
            ((tau * network_w)? + ((1.0 - tau) * target_w)?)?,
        )?;

        let target_b = vb.get(out_dim, &format!("{target_prefix}-fc{i}.bias"))?;
        let network_b = vb.get(out_dim, &format!("{network_prefix}-fc{i}.bias"))?;
        varmap.set_one(
            format!("{target_prefix}-fc{i}.bias"),
            ((tau * network_b)? + ((1.0 - tau) * target_b)?)?,
        )?;
    }
    Ok(())
}

struct Actor<'a> {
    varmap: VarMap,
    vb: VarBuilder<'a>,
    network: Sequential,
    target_network: Sequential,
    size_state: usize,
    size_action: usize,
    dims: Vec<(usize, usize)>,
}

impl Actor<'_> {
    fn new(device: &Device, dtype: DType, size_state: usize, size_action: usize) -> Result<Self> {
        let mut varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, dtype, device);

        let dims = vec![(size_state, 400), (400, 300), (300, size_action)];

        let make_network = |prefix: &str| {
            let seq = seq()
                .add(linear(
                    dims[0].0,
                    dims[0].1,
                    vb.pp(format!("{prefix}-fc0")),
                )?)
                .add(Activation::Relu)
                .add(linear(
                    dims[1].0,
                    dims[1].1,
                    vb.pp(format!("{prefix}-fc1")),
                )?)
                .add(Activation::Relu)
                .add(linear(
                    dims[2].0,
                    dims[2].1,
                    vb.pp(format!("{prefix}-fc2")),
                )?)
                .add(func(|xs| xs.tanh()));
            Ok::<Sequential, Error>(seq)
        };

        let network = make_network("actor")?;
        let target_network = make_network("target-actor")?;

        // this sets the two networks to be equal to each other using tau = 1.0
        track(&mut varmap, &vb, "target-actor", "actor", &dims, 1.0);

        Ok(Self {
            varmap,
            vb,
            network,
            target_network,
            size_state,
            size_action,
            dims,
        })
    }

    fn forward(&self, state: &Tensor) -> Result<Tensor> {
        self.network.forward(state)
    }

    fn target_forward(&self, state: &Tensor) -> Result<Tensor> {
        self.target_network.forward(state)
    }

    fn track(&mut self, tau: f64) -> Result<()> {
        track(
            &mut self.varmap,
            &self.vb,
            "target-actor",
            "actor",
            &self.dims,
            tau,
        )
    }
}

struct Critic<'a> {
    varmap: VarMap,
    vb: VarBuilder<'a>,
    network: Sequential,
    target_network: Sequential,
    size_state: usize,
    size_action: usize,
    dims: Vec<(usize, usize)>,
}

impl Critic<'_> {
    fn new(device: &Device, dtype: DType, size_state: usize, size_action: usize) -> Result<Self> {
        let mut varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, dtype, device);

        let dims: Vec<(usize, usize)> = vec![(size_state + size_action, 400), (400, 300), (300, 1)];

        let make_network = |prefix: &str| {
            let seq = seq()
                .add(linear(
                    dims[0].0,
                    dims[0].1,
                    vb.pp(format!("{prefix}-fc0")),
                )?)
                .add(Activation::Relu)
                .add(linear(
                    dims[1].0,
                    dims[1].1,
                    vb.pp(format!("{prefix}-fc1")),
                )?)
                .add(Activation::Relu)
                .add(linear(
                    dims[2].0,
                    dims[2].1,
                    vb.pp(format!("{prefix}-fc2")),
                )?);
            Ok::<Sequential, Error>(seq)
        };

        let network = make_network("critic")?;
        let target_network = make_network("target-critic")?;

        // this sets the two networks to be equal to each other using tau = 1.0
        track(&mut varmap, &vb, "target-critic", "critic", &dims, 1.0);

        Ok(Self {
            varmap,
            vb,
            network,
            target_network,
            size_state,
            size_action,
            dims,
        })
    }

    fn forward(&self, state: &Tensor, action: &Tensor) -> Result<Tensor> {
        let xs = Tensor::cat(&[action, state], 1)?;
        self.network.forward(&xs)
    }

    fn target_forward(&self, state: &Tensor, action: &Tensor) -> Result<Tensor> {
        let xs = Tensor::cat(&[action, state], 1)?;
        self.target_network.forward(&xs)
    }

    fn track(&mut self, tau: f64) -> Result<()> {
        track(
            &mut self.varmap,
            &self.vb,
            "target-critic",
            "critic",
            &self.dims,
            tau,
        )
    }
}

#[allow(clippy::upper_case_acronyms)]
pub struct DDPG<'a> {
    actor: Actor<'a>,
    actor_optim: AdamW,
    critic: Critic<'a>,
    critic_optim: AdamW,
    gamma: f64,
    tau: f64,
    replay_buffer: ReplayBuffer,
    ou_noise: OuNoise,

    size_state: usize,
    size_action: usize,
    pub train: bool,
}

impl DDPG<'_> {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        device: &Device,
        size_state: usize,
        size_action: usize,
        train: bool,
        actor_lr: f64,
        critic_lr: f64,
        gamma: f64,
        tau: f64,
        buffer_capacity: usize,
        ou_noise: OuNoise,
    ) -> Result<Self> {
        let filter_by_prefix = |varmap: &VarMap, prefix: &str| {
            varmap
                .data()
                .lock()
                .unwrap()
                .iter()
                .filter_map(|(name, var)| name.starts_with(prefix).then_some(var.clone()))
                .collect::<Vec<Var>>()
        };

        let actor = Actor::new(device, DType::F32, size_state, size_action)?;
        let actor_optim = AdamW::new(
            filter_by_prefix(&actor.varmap, "actor"),
            ParamsAdamW {
                lr: actor_lr,
                ..Default::default()
            },
        )?;

        let critic = Critic::new(device, DType::F32, size_state, size_action)?;
        let critic_optim = AdamW::new(
            filter_by_prefix(&critic.varmap, "critic"),
            ParamsAdamW {
                lr: critic_lr,
                ..Default::default()
            },
        )?;

        Ok(Self {
            actor,
            actor_optim,
            critic,
            critic_optim,
            gamma,
            tau,
            replay_buffer: ReplayBuffer::new(buffer_capacity),
            ou_noise,
            size_state,
            size_action,
            train,
        })
    }

    pub fn remember(
        &mut self,
        state: &Tensor,
        action: &Tensor,
        reward: &Tensor,
        next_state: &Tensor,
        terminated: bool,
        truncated: bool,
    ) {
        self.replay_buffer
            .push(state, action, reward, next_state, terminated, truncated)
    }

    pub fn actions(&mut self, state: &Tensor) -> Result<f32> {
        let actions = self
            .actor
            .forward(&state.detach().unsqueeze(0)?)?
            .squeeze(0)?;
        let actions = if self.train {
            (actions + self.ou_noise.sample()?)?
        } else {
            actions
        };
        actions.squeeze(0)?.to_scalar::<f32>()
    }

    pub fn train(&mut self, batch_size: usize) -> Result<()> {
        let (states, actions, rewards, next_states, _, _) =
            match self.replay_buffer.random_batch(batch_size)? {
                Some(v) => v,
                _ => return Ok(()),
            };

        let q_target = self
            .critic
            .target_forward(&next_states, &self.actor.target_forward(&next_states)?)?;
        let q_target = (rewards + (self.gamma * q_target)?.detach())?;
        let q = self.critic.forward(&states, &actions)?;
        let diff = (q_target - q)?;

        let critic_loss = diff.sqr()?.mean_all()?;
        self.critic_optim.backward_step(&critic_loss)?;

        let actor_loss = self
            .critic
            .forward(&states, &self.actor.forward(&states)?)?
            .mean_all()?
            .neg()?;
        self.actor_optim.backward_step(&actor_loss)?;

        self.critic.track(self.tau)?;
        self.actor.track(self.tau)?;

        Ok(())
    }
}

// The impact of the q value of the next state on the current state's q value.
const GAMMA: f64 = 0.99;
// The weight for updating the target networks.
const TAU: f64 = 0.005;
// The capacity of the replay buffer used for sampling training data.
const REPLAY_BUFFER_CAPACITY: usize = 100_000;
// The training batch size for each training iteration.
const TRAINING_BATCH_SIZE: usize = 100;
// The total number of episodes.
const MAX_EPISODES: usize = 100;
// The maximum length of an episode.
const EPISODE_LENGTH: usize = 200;
// The number of training iterations after one episode finishes.
const TRAINING_ITERATIONS: usize = 200;

// Ornstein-Uhlenbeck process parameters.
const MU: f64 = 0.0;
const THETA: f64 = 0.15;
const SIGMA: f64 = 0.1;

const ACTOR_LEARNING_RATE: f64 = 1e-4;
const CRITIC_LEARNING_RATE: f64 = 1e-3;

pub fn run() -> Result<()> {
    let env = GymEnv::new("Pendulum-v1")?;
    println!("action space: {}", env.action_space());
    println!("observation space: {:?}", env.observation_space());

    let size_state = env.observation_space().iter().product::<usize>();
    let size_action = env.action_space();

    let mut agent = DDPG::new(
        &Device::Cpu,
        size_state,
        size_action,
        true,
        ACTOR_LEARNING_RATE,
        CRITIC_LEARNING_RATE,
        GAMMA,
        TAU,
        REPLAY_BUFFER_CAPACITY,
        OuNoise::new(MU, THETA, SIGMA, size_action)?,
    )?;
    let mut rng = rand::thread_rng();

    for episode in 0..MAX_EPISODES {
        // let mut state = env.reset(episode as u64)?;
        let mut state = env.reset(rng.gen::<u64>())?;

        let mut total_reward = 0.0;
        for _ in 0..EPISODE_LENGTH {
            let mut action = 2.0 * agent.actions(&state)?;
            action = action.clamp(-2.0, 2.0);

            let step = env.step(vec![action])?;
            total_reward += step.reward;

            agent.remember(
                &state,
                &Tensor::new(vec![action], &Device::Cpu)?,
                &Tensor::new(vec![step.reward as f32], &Device::Cpu)?,
                &step.state,
                step.terminated,
                step.truncated,
            );

            if step.terminated || step.truncated {
                break;
            }
            state = step.state;
        }

        println!("episode {episode} with total reward of {total_reward}");

        for _ in 0..TRAINING_ITERATIONS {
            agent.train(TRAINING_BATCH_SIZE)?;
        }
    }

    println!("Testing...");
    agent.train = false;
    for episode in 0..10 {
        // let mut state = env.reset(episode as u64)?;
        let mut state = env.reset(rng.gen::<u64>())?;
        let mut total_reward = 0.0;
        for _ in 0..EPISODE_LENGTH {
            let mut action = 2.0 * agent.actions(&state)?;
            action = action.clamp(-2.0, 2.0);

            let step = env.step(vec![action])?;
            total_reward += step.reward;

            if step.terminated || step.truncated {
                break;
            }
            state = step.state;
        }
        println!("episode {episode} with total reward of {total_reward}");
    }
    Ok(())
}
```
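
For readers following `ddpg.rs`, the three updates the code performs correspond to the usual DDPG equations: `OuNoise::sample` integrates an Ornstein-Uhlenbeck step, `DDPG::train` regresses the critic towards a TD target built from the target networks, and `track` applies the soft target update with the `GAMMA` and `TAU` constants defined above. Restated in notation (not part of the commit):

```math
\begin{aligned}
x_{t+1} &= x_t + \theta\,(\mu - x_t) + \sigma\,\varepsilon_t, \qquad \varepsilon_t \sim \mathcal{N}(0, 1)\\
y &= r + \gamma\, Q_{\mathrm{target}}\!\big(s', \pi_{\mathrm{target}}(s')\big), \qquad \mathcal{L}_{\mathrm{critic}} = \big(y - Q(s, a)\big)^2\\
\theta_{\mathrm{target}} &\leftarrow \tau\,\theta + (1 - \tau)\,\theta_{\mathrm{target}}
\end{aligned}
```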

candle-examples/examples/reinforcement-learning/dqn.rs (new file, mode 100644)

```rust
use std::collections::VecDeque;

use rand::distributions::Uniform;
use rand::{thread_rng, Rng};

use candle::{DType, Device, Module, Result, Tensor};
use candle_nn::loss::mse;
use candle_nn::{linear, seq, Activation, AdamW, Optimizer, VarBuilder, VarMap};

use crate::gym_env::GymEnv;

const DEVICE: Device = Device::Cpu;
const EPISODES: usize = 200;
const BATCH_SIZE: usize = 64;
const GAMMA: f64 = 0.99;
const LEARNING_RATE: f64 = 0.01;

pub fn run() -> Result<()> {
    let env = GymEnv::new("CartPole-v1")?;

    // Build the model that predicts the estimated rewards given a specific state.
    let var_map = VarMap::new();
    let vb = VarBuilder::from_varmap(&var_map, DType::F32, &DEVICE);
    let observation_space = *env.observation_space().first().unwrap();

    let model = seq()
        .add(linear(observation_space, 64, vb.pp("linear_in"))?)
        .add(Activation::Relu)
        .add(linear(64, env.action_space(), vb.pp("linear_out"))?);

    let mut optimizer = AdamW::new_lr(var_map.all_vars(), LEARNING_RATE)?;

    // Initialize the model's memory.
    let mut memory = VecDeque::with_capacity(10000);

    // Start the training loop.
    let mut state = env.reset(0)?;
    let mut episode = 0;
    let mut accumulate_rewards = 0.0;
    while episode < EPISODES {
        // Given the current state, predict the estimated rewards, and take the
        // action that is expected to return the most rewards.
        let estimated_rewards = model.forward(&state.unsqueeze(0)?)?;
        let action: u32 = estimated_rewards.squeeze(0)?.argmax(0)?.to_scalar()?;

        // Take that action in the environment, and memorize the outcome:
        // - the state for which the action was taken
        // - the action taken
        // - the new state resulting of taking that action
        // - the actual rewards of taking that action
        // - whether the environment reached a terminal state or not (e.g. game over)
        let step = env.step(action)?;
        accumulate_rewards += step.reward;
        memory.push_back((
            state,
            action,
            step.state.clone(),
            step.reward,
            step.terminated || step.truncated,
        ));
        state = step.state;

        // If there's enough entries in the memory, perform a learning step, where
        // BATCH_SIZE transitions will be sampled from the memory and will be
        // fed to the model so that it performs a backward pass.
        if memory.len() > BATCH_SIZE {
            // Sample randomly from the memory.
            let batch = thread_rng()
                .sample_iter(Uniform::from(0..memory.len()))
                .take(BATCH_SIZE)
                .map(|i| memory.get(i).unwrap().clone())
                .collect::<Vec<_>>();

            // Group all the samples together into tensors with the appropriate shape.
            let states: Vec<_> = batch.iter().map(|e| e.0.clone()).collect();
            let states = Tensor::stack(&states, 0)?;

            let actions = batch.iter().map(|e| e.1);
            let actions = Tensor::from_iter(actions, &DEVICE)?.unsqueeze(1)?;

            let next_states: Vec<_> = batch.iter().map(|e| e.2.clone()).collect();
            let next_states = Tensor::stack(&next_states, 0)?;

            let rewards = batch.iter().map(|e| e.3 as f32);
            let rewards = Tensor::from_iter(rewards, &DEVICE)?.unsqueeze(1)?;

            let non_final_mask = batch.iter().map(|e| !e.4 as u8 as f32);
            let non_final_mask = Tensor::from_iter(non_final_mask, &DEVICE)?.unsqueeze(1)?;

            // Get the estimated rewards for the actions that where taken at each step.
            let estimated_rewards = model.forward(&states)?;
            let x = estimated_rewards.gather(&actions, 1)?;

            // Get the maximum expected rewards for the next state, apply them a discount rate
            // GAMMA and add them to the rewards that were actually gathered on the current state.
            // If the next state is a terminal state, just omit maximum estimated
            // rewards for that state.
            let expected_rewards = model.forward(&next_states)?.detach();
            let y = expected_rewards.max_keepdim(1)?;
            let y = (y * GAMMA * non_final_mask + rewards)?;

            // Compare the estimated rewards with the maximum expected rewards and
            // perform the backward step.
            let loss = mse(&x, &y)?;
            optimizer.backward_step(&loss)?;
        }

        // If we are on a terminal state, reset the environment and log how it went.
        if step.terminated || step.truncated {
            episode += 1;
            println!("Episode {episode} | Rewards {}", accumulate_rewards as i64);
            state = env.reset(0)?;
            accumulate_rewards = 0.0;
        }
    }

    Ok(())
}
```
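
The learning step in `dqn.rs` computes the standard one-step Q-learning target, using `non_final_mask` to drop the bootstrap term for terminal transitions, and fits the network with a mean-squared error over the batch. Written out for reference:

```math
y = r + \gamma \,(1 - d)\, \max_{a'} Q(s', a'), \qquad
\mathcal{L} = \frac{1}{B}\sum_{i=1}^{B} \big(Q(s_i, a_i) - y_i\big)^2
```

where d is 1 for transitions that ended the episode and 0 otherwise.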

candle-examples/examples/reinforcement-learning/gym_env.rs (new file, mode 100644)

```rust
#![allow(unused)]
//! Wrappers around the Python API of Gymnasium (the new version of OpenAI gym)
use candle::{Device, Result, Tensor};
use pyo3::prelude::*;
use pyo3::types::PyDict;

/// The return value for a step.
#[derive(Debug)]
pub struct Step<A> {
    pub state: Tensor,
    pub action: A,
    pub reward: f64,
    pub terminated: bool,
    pub truncated: bool,
}

impl<A: Copy> Step<A> {
    /// Returns a copy of this step changing the observation tensor.
    pub fn copy_with_obs(&self, state: &Tensor) -> Step<A> {
        Step {
            state: state.clone(),
            action: self.action,
            reward: self.reward,
            terminated: self.terminated,
            truncated: self.truncated,
        }
    }
}

/// An OpenAI Gym session.
pub struct GymEnv {
    env: PyObject,
    action_space: usize,
    observation_space: Vec<usize>,
}

fn w(res: PyErr) -> candle::Error {
    candle::Error::wrap(res)
}

impl GymEnv {
    /// Creates a new session of the specified OpenAI Gym environment.
    pub fn new(name: &str) -> Result<GymEnv> {
        Python::with_gil(|py| {
            let gym = py.import_bound("gymnasium")?;
            let make = gym.getattr("make")?;
            let env = make.call1((name,))?;
            let action_space = env.getattr("action_space")?;
            let action_space = if let Ok(val) = action_space.getattr("n") {
                val.extract()?
            } else {
                let action_space: Vec<usize> = action_space.getattr("shape")?.extract()?;
                action_space[0]
            };
            let observation_space = env.getattr("observation_space")?;
            let observation_space = observation_space.getattr("shape")?.extract()?;
            Ok(GymEnv {
                env: env.into(),
                action_space,
                observation_space,
            })
        })
        .map_err(w)
    }

    /// Resets the environment, returning the observation tensor.
    pub fn reset(&self, seed: u64) -> Result<Tensor> {
        let state: Vec<f32> = Python::with_gil(|py| {
            let kwargs = PyDict::new_bound(py);
            kwargs.set_item("seed", seed)?;
            let state = self.env.call_method_bound(py, "reset", (), Some(&kwargs))?;
            state.bind(py).get_item(0)?.extract()
        })
        .map_err(w)?;
        Tensor::new(state, &Device::Cpu)
    }

    /// Applies an environment step using the specified action.
    pub fn step<A: pyo3::IntoPy<pyo3::Py<pyo3::PyAny>> + Clone>(
        &self,
        action: A,
    ) -> Result<Step<A>> {
        let (state, reward, terminated, truncated) = Python::with_gil(|py| {
            let step = self.env.call_method_bound(py, "step", (action.clone(),), None)?;
            let step = step.bind(py);
            let state: Vec<f32> = step.get_item(0)?.extract()?;
            let reward: f64 = step.get_item(1)?.extract()?;
            let terminated: bool = step.get_item(2)?.extract()?;
            let truncated: bool = step.get_item(3)?.extract()?;
            Ok((state, reward, terminated, truncated))
        })
        .map_err(w)?;
        let state = Tensor::new(state, &Device::Cpu)?;
        Ok(Step {
            state,
            action,
            reward,
            terminated,
            truncated,
        })
    }

    /// Returns the number of allowed actions for this environment.
    pub fn action_space(&self) -> usize {
        self.action_space
    }

    /// Returns the shape of the observation tensors.
    pub fn observation_space(&self) -> &[usize] {
        &self.observation_space
    }
}
```
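
As a minimal sketch of how `GymEnv` is driven by the example binaries (not part of this commit, and assuming a `gymnasium` installation that pyo3 can import), a single rollout looks like the following; the constant action stands in for whatever policy the calling example uses:

```rust
use candle::Result;

use crate::gym_env::GymEnv;

/// Roll out one episode with a fixed action and return the total reward.
fn rollout_once(env: &GymEnv) -> Result<f64> {
    // Reset with a fixed seed; the returned observation tensor is ignored here.
    let _obs = env.reset(42)?;
    let mut total_reward = 0.0;
    loop {
        // A real agent would map the latest observation to an action.
        let step = env.step(0i64)?;
        total_reward += step.reward;
        if step.terminated || step.truncated {
            break;
        }
    }
    Ok(total_reward)
}
```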

candle-examples/examples/reinforcement-learning/main.rs (new file, mode 100644)

```rust
#![allow(unused)]

#[cfg(any(feature = "mkl", feature = "mkl-dynamic"))]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use candle::Result;
use clap::{Parser, Subcommand};

mod gym_env;
mod vec_gym_env;

mod ddpg;
mod dqn;
mod policy_gradient;

#[derive(Parser)]
struct Args {
    #[command(subcommand)]
    command: Command,
}

#[derive(Subcommand)]
enum Command {
    Pg,
    Ddpg,
    Dqn,
}

fn main() -> Result<()> {
    let args = Args::parse();
    match args.command {
        Command::Pg => policy_gradient::run()?,
        Command::Ddpg => ddpg::run()?,
        Command::Dqn => dqn::run()?,
    }
    Ok(())
}
```

candle-examples/examples/reinforcement-learning/policy_gradient.rs (new file, mode 100644)

```rust
use super::gym_env::{GymEnv, Step};
use candle::{DType, Device, Error, Module, Result, Tensor};
use candle_nn::{
    linear, ops::log_softmax, ops::softmax, sequential::seq, Activation, AdamW, Optimizer,
    ParamsAdamW, VarBuilder, VarMap,
};
use rand::{distributions::Distribution, rngs::ThreadRng, Rng};

fn new_model(
    input_shape: &[usize],
    num_actions: usize,
    dtype: DType,
    device: &Device,
) -> Result<(impl Module, VarMap)> {
    let input_size = input_shape.iter().product();

    let mut varmap = VarMap::new();
    let var_builder = VarBuilder::from_varmap(&varmap, dtype, device);

    let model = seq()
        .add(linear(input_size, 32, var_builder.pp("lin1"))?)
        .add(Activation::Relu)
        .add(linear(32, num_actions, var_builder.pp("lin2"))?);

    Ok((model, varmap))
}

fn accumulate_rewards(steps: &[Step<i64>]) -> Vec<f64> {
    let mut rewards: Vec<f64> = steps.iter().map(|s| s.reward).collect();
    let mut acc_reward = 0f64;
    for (i, reward) in rewards.iter_mut().enumerate().rev() {
        if steps[i].terminated {
            acc_reward = 0.0;
        }
        acc_reward += *reward;
        *reward = acc_reward;
    }
    rewards
}

fn weighted_sample(probs: Vec<f32>, rng: &mut ThreadRng) -> Result<usize> {
    let distribution = rand::distributions::WeightedIndex::new(probs).map_err(Error::wrap)?;
    let mut rng = rng;
    Ok(distribution.sample(&mut rng))
}

pub fn run() -> Result<()> {
    let env = GymEnv::new("CartPole-v1")?;

    println!("action space: {:?}", env.action_space());
    println!("observation space: {:?}", env.observation_space());

    let (model, varmap) = new_model(
        env.observation_space(),
        env.action_space(),
        DType::F32,
        &Device::Cpu,
    )?;

    let optimizer_params = ParamsAdamW {
        lr: 0.01,
        weight_decay: 0.01,
        ..Default::default()
    };

    let mut optimizer = AdamW::new(varmap.all_vars(), optimizer_params)?;

    let mut rng = rand::thread_rng();

    for epoch_idx in 0..100 {
        let mut state = env.reset(rng.gen::<u64>())?;
        let mut steps: Vec<Step<i64>> = vec![];

        loop {
            let action = {
                let action_probs: Vec<f32> =
                    softmax(&model.forward(&state.detach().unsqueeze(0)?)?, 1)?
                        .squeeze(0)?
                        .to_vec1()?;
                weighted_sample(action_probs, &mut rng)? as i64
            };

            let step = env.step(action)?;
            steps.push(step.copy_with_obs(&state));

            if step.terminated || step.truncated {
                state = env.reset(rng.gen::<u64>())?;
                if steps.len() > 5000 {
                    break;
                }
            } else {
                state = step.state;
            }
        }

        let total_reward: f64 = steps.iter().map(|s| s.reward).sum();
        let episodes: i64 = steps
            .iter()
            .map(|s| (s.terminated || s.truncated) as i64)
            .sum();
        println!(
            "epoch: {:<3} episodes: {:<5} avg reward per episode: {:.2}",
            epoch_idx,
            episodes,
            total_reward / episodes as f64
        );

        let batch_size = steps.len();

        let rewards = Tensor::from_vec(accumulate_rewards(&steps), batch_size, &Device::Cpu)?
            .to_dtype(DType::F32)?
            .detach();

        let actions_mask = {
            let actions: Vec<i64> = steps.iter().map(|s| s.action).collect();
            let actions_mask: Vec<Tensor> = actions
                .iter()
                .map(|&action| {
                    // One-hot encoding
                    let mut action_mask = vec![0.0; env.action_space()];
                    action_mask[action as usize] = 1.0;

                    Tensor::from_vec(action_mask, env.action_space(), &Device::Cpu)
                        .unwrap()
                        .to_dtype(DType::F32)
                        .unwrap()
                })
                .collect();
            Tensor::stack(&actions_mask, 0)?.detach()
        };

        let states = {
            let states: Vec<Tensor> = steps.into_iter().map(|s| s.state).collect();
            Tensor::stack(&states, 0)?.detach()
        };

        let log_probs = actions_mask
            .mul(&log_softmax(&model.forward(&states)?, 1)?)?
            .sum(1)?;

        let loss = rewards.mul(&log_probs)?.neg()?.mean_all()?;
        optimizer.backward_step(&loss)?;
    }

    Ok(())
}
```
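
The loss assembled at the end of `policy_gradient.rs` is the REINFORCE objective, with the per-step returns accumulated backwards inside each episode by `accumulate_rewards`. In equation form (a restatement, not part of the commit):

```math
\mathcal{L} = -\frac{1}{N}\sum_{t=1}^{N} R_t \,\log \pi_\theta(a_t \mid s_t),
\qquad R_t = \sum_{t' \ge t \ (\text{same episode})} r_{t'}
```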

candle-examples/examples/reinforcement-learning/vec_gym_env.rs (new file, mode 100644)

```rust
#![allow(unused)]
//! Vectorized version of the gym environment.
use candle::{DType, Device, Result, Tensor};
use pyo3::prelude::*;
use pyo3::types::PyDict;

#[derive(Debug)]
pub struct Step {
    pub obs: Tensor,
    pub reward: Tensor,
    pub is_done: Tensor,
}

pub struct VecGymEnv {
    env: PyObject,
    action_space: usize,
    observation_space: Vec<usize>,
}

fn w(res: PyErr) -> candle::Error {
    candle::Error::wrap(res)
}

impl VecGymEnv {
    pub fn new(name: &str, img_dir: Option<&str>, nprocesses: usize) -> Result<VecGymEnv> {
        Python::with_gil(|py| {
            let sys = py.import_bound("sys")?;
            let path = sys.getattr("path")?;
            let _ = path.call_method1(
                "append",
                ("candle-examples/examples/reinforcement-learning",),
            )?;
            let gym = py.import_bound("atari_wrappers")?;
            let make = gym.getattr("make")?;
            let env = make.call1((name, img_dir, nprocesses))?;
            let action_space = env.getattr("action_space")?;
            let action_space = action_space.getattr("n")?.extract()?;
            let observation_space = env.getattr("observation_space")?;
            let observation_space: Vec<usize> = observation_space.getattr("shape")?.extract()?;
            let observation_space =
                [vec![nprocesses].as_slice(), observation_space.as_slice()].concat();
            Ok(VecGymEnv {
                env: env.into(),
                action_space,
                observation_space,
            })
        })
        .map_err(w)
    }

    pub fn reset(&self) -> Result<Tensor> {
        let obs = Python::with_gil(|py| {
            let obs = self.env.call_method0(py, "reset")?;
            let obs = obs.call_method0(py, "flatten")?;
            obs.extract::<Vec<f32>>(py)
        })
        .map_err(w)?;
        Tensor::new(obs, &Device::Cpu)?.reshape(self.observation_space.as_slice())
    }

    pub fn step(&self, action: Vec<usize>) -> Result<Step> {
        let (obs, reward, is_done) = Python::with_gil(|py| {
            let step = self.env.call_method_bound(py, "step", (action,), None)?;
            let step = step.bind(py);
            let obs = step.get_item(0)?.call_method("flatten", (), None)?;
            let obs_buffer = pyo3::buffer::PyBuffer::get_bound(&obs)?;
            let obs: Vec<u8> = obs_buffer.to_vec(py)?;
            let reward: Vec<f32> = step.get_item(1)?.extract()?;
            let is_done: Vec<f32> = step.get_item(2)?.extract()?;
            Ok((obs, reward, is_done))
        })
        .map_err(w)?;
        let obs = Tensor::from_vec(obs, self.observation_space.as_slice(), &Device::Cpu)?
            .to_dtype(DType::F32)?;
        let reward = Tensor::new(reward, &Device::Cpu)?;
        let is_done = Tensor::new(is_done, &Device::Cpu)?;
        Ok(Step {
            obs,
            reward,
            is_done,
        })
    }

    pub fn action_space(&self) -> usize {
        self.action_space
    }

    pub fn observation_space(&self) -> &[usize] {
        &self.observation_space
    }
}
```

candle-examples/examples/replit-code/README.md (new file, mode 100644)

# candle-replit-code: code completion specialized model.

[replit-code-v1_5-3b](https://huggingface.co/replit/replit-code-v1_5-3b) is a
language model specialized for code completion. This model uses 3.3B parameters
in `bfloat16` (so the GPU version will only work on recent nvidia cards).

## Running some example

```bash
cargo run --example replit-code --release -- --prompt 'def fibonacci(n): '
```

This produces the following output.

```
def fibonacci(n): # write Fibonacci series up to n
    """Print a Fibonacci series up to n."""
    a, b = 0, 1
    while a < n:
        print(a, end=' ')
        a, b = b, a+b
    print()

def fibonacci_loop(n): # write Fibonacci series up to n
    """Print a Fibonacci series up to n."""
    result = []
    a, b = 0, 1
    while a < n:
        result.append(a)
        a, b = b, a+b
    return result

def fibonacci_generator(n): # write Fibonacci series up to n
    """Print a Fibonacci series up to n."""
    a, b = 0, 1
    while a < n:
        yield a
        a, b = b, a+b
```

candle-examples/examples/replit-code/main.rs (new file, mode 100644)

```rust
#[cfg(any(feature = "mkl", feature = "mkl-dynamic"))]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::{Error as E, Result};
use clap::Parser;

use candle_transformers::models::mpt::{Config, Model as M};
use candle_transformers::models::quantized_mpt::Model as Q;

use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

enum Model {
    M(M),
    Q(Q),
}

impl Model {
    fn forward(&mut self, xs: &Tensor) -> candle::Result<Tensor> {
        match self {
            Self::M(model) => model.forward(xs),
            Self::Q(model) => model.forward(xs),
        }
    }
}

struct TextGeneration {
    model: Model,
    device: Device,
    tokenizer: Tokenizer,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
    verbose_prompt: bool,
}

impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    fn new(
        model: Model,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        verbose_prompt: bool,
        device: &Device,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
        Self {
            model,
            tokenizer,
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            verbose_prompt,
            device: device.clone(),
        }
    }

    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
        use std::io::Write;
        println!("starting the inference loop");
        let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
        if tokens.is_empty() {
            anyhow::bail!("Empty prompts are not supported in the phi model.")
        }
        if self.verbose_prompt {
            for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
                let token = token.replace('▁', " ").replace("<0x0A>", "\n");
                println!("{id:7} -> '{token}'");
            }
        }
        let mut tokens = tokens.get_ids().to_vec();
        let mut generated_tokens = 0usize;
        let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
            Some(token) => *token,
            None => anyhow::bail!("cannot find the endoftext token"),
        };
        print!("{prompt}");
        std::io::stdout().flush()?;
        let start_gen = std::time::Instant::now();
        for index in 0..sample_len {
            let context_size = if index > 0 { 1 } else { tokens.len() };
            let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
            let logits = self.model.forward(&input)?;
            let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    self.repeat_penalty,
                    &tokens[start_at..],
                )?
            };

            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == eos_token {
                break;
            }
            let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
            print!("{token}");
            std::io::stdout().flush()?;
        }
        let dt = start_gen.elapsed();
        println!(
            "\n{generated_tokens} tokens generated ({:.2} token/s)",
            generated_tokens as f64 / dt.as_secs_f64(),
        );
        Ok(())
    }
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// Display the token for the specified prompt.
    #[arg(long)]
    verbose_prompt: bool,

    #[arg(long)]
    prompt: String,

    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(long, short = 'n', default_value_t = 1000)]
    sample_len: usize,

    #[arg(long)]
    model_id: Option<String>,

    #[arg(long)]
    revision: Option<String>,

    #[arg(long)]
    quantized: bool,

    #[arg(long)]
    weight_file: Option<String>,

    #[arg(long)]
    tokenizer: Option<String>,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature.unwrap_or(0.),
        args.repeat_penalty,
        args.repeat_last_n
    );

    let start = std::time::Instant::now();
    let api = Api::new()?;
    let model_id = match args.model_id {
        Some(model_id) => model_id.to_string(),
        None => "lmz/candle-replit-code".to_string(),
    };
    let revision = match args.revision {
        Some(rev) => rev.to_string(),
        None => "main".to_string(),
    };
    let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
    let tokenizer_filename = match args.tokenizer {
        Some(file) => std::path::PathBuf::from(file),
        None => repo.get("tokenizer.json")?,
    };
    let filename = match args.weight_file {
        Some(weight_file) => std::path::PathBuf::from(weight_file),
        None => {
            if args.quantized {
                repo.get("model-replit-code-v1_5-q4k.gguf")?
            } else {
                repo.get("model.safetensors")?
            }
        }
    };
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

    let start = std::time::Instant::now();
    let device = candle_examples::device(args.cpu)?;
    let config = Config::replit_code_v1_5_3b();
    let model = if args.quantized {
        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
            &filename, &device,
        )?;
        Model::Q(Q::new(&config, vb.pp("transformer"))?)
    } else {
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
        Model::M(M::new(&config, vb.pp("transformer"))?)
    };
    println!("loaded the model in {:?}", start.elapsed());

    let mut pipeline = TextGeneration::new(
        model,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.repeat_penalty,
        args.repeat_last_n,
        args.verbose_prompt,
        &device,
    );
    pipeline.run(&args.prompt, args.sample_len)?;
    Ok(())
}
```

candle-examples/examples/repvgg/README.md (new file, mode 100644)

# candle-repvgg

[RepVGG: Making VGG-style ConvNets Great Again](https://arxiv.org/abs/2101.03697).

This candle implementation uses a pre-trained RepVGG network for inference. The
classification head has been trained on the ImageNet dataset and returns the
probabilities for the top-5 classes.

## Running an example

```
$ cargo run --example repvgg --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg

loaded image Tensor[dims 3, 224, 224; f32]
model built
mountain bike, all-terrain bike, off-roader: 61.70%
bicycle-built-for-two, tandem bicycle, tandem: 33.14%
unicycle, monocycle: 4.88%
crash helmet: 0.15%
moped: 0.04%
```

candle-examples/examples/repvgg/main.rs (new file, mode 100644)

```rust
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use clap::{Parser, ValueEnum};

use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::repvgg;

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
    A0,
    A1,
    A2,
    B0,
    B1,
    B2,
    B3,
    B1G4,
    B2G4,
    B3G4,
}

impl Which {
    fn model_filename(&self) -> String {
        let name = match self {
            Self::A0 => "a0",
            Self::A1 => "a1",
            Self::A2 => "a2",
            Self::B0 => "b0",
            Self::B1 => "b1",
            Self::B2 => "b2",
            Self::B3 => "b3",
            Self::B1G4 => "b1g4",
            Self::B2G4 => "b2g4",
            Self::B3G4 => "b3g4",
        };
        format!("timm/repvgg_{}.rvgg_in1k", name)
    }

    fn config(&self) -> repvgg::Config {
        match self {
            Self::A0 => repvgg::Config::a0(),
            Self::A1 => repvgg::Config::a1(),
            Self::A2 => repvgg::Config::a2(),
            Self::B0 => repvgg::Config::b0(),
            Self::B1 => repvgg::Config::b1(),
            Self::B2 => repvgg::Config::b2(),
            Self::B3 => repvgg::Config::b3(),
            Self::B1G4 => repvgg::Config::b1g4(),
            Self::B2G4 => repvgg::Config::b2g4(),
            Self::B3G4 => repvgg::Config::b3g4(),
        }
    }
}

#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    image: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    #[arg(value_enum, long, default_value_t=Which::A0)]
    which: Which,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {
        None => {
            let model_name = args.which.model_filename();
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model(model_name);
            api.get("model.safetensors")?
        }
        Some(model) => model.into(),
    };

    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
    let model = repvgg::repvgg(&args.which.config(), 1000, vb)?;
    println!("model built");
    let logits = model.forward(&image.unsqueeze(0)?)?;
    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
        .i(0)?
        .to_vec1::<f32>()?;
    let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
    prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
    for &(category_idx, pr) in prs.iter().take(5) {
        println!(
            "{:24}: {:.2}%",
            candle_examples::imagenet::CLASSES[category_idx],
            100. * pr
        );
    }
    Ok(())
}
```

candle-examples/examples/resnet/README.md (new file, mode 100644)

# candle-resnet

A candle implementation of inference using a pre-trained
[ResNet](https://arxiv.org/abs/1512.03385).
This uses a classification head trained on the ImageNet dataset and returns the
probabilities for the top-5 classes.

## Running an example

```
$ cargo run --example resnet --release -- --image tiger.jpg

loaded image Tensor[dims 3, 224, 224; f32]
model built
tiger, Panthera tigris: 90.21%
tiger cat: 8.93%
lion, king of beasts, Panthera leo: 0.35%
leopard, Panthera pardus: 0.16%
jaguar, panther, Panthera onca, Felis onca: 0.09%
```

candle-examples/examples/resnet/export_models.py (new file, mode 100644)

```python
# This script exports pre-trained model weights in the safetensors format.
import numpy as np
import torch
import torchvision
from safetensors import torch as stt

m = torchvision.models.resnet50(pretrained=True)
stt.save_file(m.state_dict(), 'resnet50.safetensors')

m = torchvision.models.resnet101(pretrained=True)
stt.save_file(m.state_dict(), 'resnet101.safetensors')

m = torchvision.models.resnet152(pretrained=True)
stt.save_file(m.state_dict(), 'resnet152.safetensors')
```

candle-examples/examples/resnet/main.rs (new file, mode 100644)

```rust
#[cfg(any(feature = "mkl", feature = "mkl-dynamic"))]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::resnet;
use clap::{Parser, ValueEnum};

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
    #[value(name = "18")]
    Resnet18,
    #[value(name = "34")]
    Resnet34,
    #[value(name = "50")]
    Resnet50,
    #[value(name = "101")]
    Resnet101,
    #[value(name = "152")]
    Resnet152,
}

#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    image: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Variant of the model to use.
    #[arg(value_enum, long, default_value_t = Which::Resnet18)]
    which: Which,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model("lmz/candle-resnet".into());
            let filename = match args.which {
                Which::Resnet18 => "resnet18.safetensors",
                Which::Resnet34 => "resnet34.safetensors",
                Which::Resnet50 => "resnet50.safetensors",
                Which::Resnet101 => "resnet101.safetensors",
                Which::Resnet152 => "resnet152.safetensors",
            };
            api.get(filename)?
        }
        Some(model) => model.into(),
    };

    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
    let class_count = candle_examples::imagenet::CLASS_COUNT as usize;
    let model = match args.which {
        Which::Resnet18 => resnet::resnet18(class_count, vb)?,
        Which::Resnet34 => resnet::resnet34(class_count, vb)?,
        Which::Resnet50 => resnet::resnet50(class_count, vb)?,
        Which::Resnet101 => resnet::resnet101(class_count, vb)?,
        Which::Resnet152 => resnet::resnet152(class_count, vb)?,
    };
    println!("model built");
    let logits = model.forward(&image.unsqueeze(0)?)?;
    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
        .i(0)?
        .to_vec1::<f32>()?;
    let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
    prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
    for &(category_idx, pr) in prs.iter().take(5) {
        println!(
            "{:24}: {:.2}%",
            candle_examples::imagenet::CLASSES[category_idx],
            100. * pr
        );
    }
    Ok(())
}
```
candle-examples/examples/rwkv/README.md
0 → 100644
View file @
25d2752f
## candle-rwkv
The [RWKV model](https://wiki.rwkv.com/) is a recurrent neural network model
with performance on par with transformer architectures. Several variants are
available; candle implements the v5 and v6 versions and can be used with
Eagle 7B ([blog post](https://blog.rwkv.com/p/eagle-7b-soaring-past-transformers)).
```bash
$ cargo run --example rwkv --release -- --prompt "The smallest prime is "
avx: true, neon: false, simd128: false, f16c: true
temp: 0.00 repeat-penalty: 1.10 repeat-last-n: 64
The smallest prime is ϕ(2) = 2.
The smallest composite is ϕ(3) = 3.
The smallest perfect number is ϕ(5) = 5.
The smallest perfect square is ϕ(4) = 4.
The smallest perfect cube is ϕ(6) = 6.
```
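The example also supports the q4k-quantized GGUF weights via the `--quantized` flag defined in `main.rs` below; a sketch of such a run (output not reproduced here):

```bash
$ cargo run --example rwkv --release -- --quantized --prompt "The smallest prime is "
```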
candle-examples/examples/rwkv/main.rs
0 → 100644
View file @
25d2752f
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::Result;
use clap::{Parser, ValueEnum};

use candle_transformers::models::quantized_rwkv_v5::Model as Q5;
use candle_transformers::models::quantized_rwkv_v6::Model as Q6;
use candle_transformers::models::rwkv_v5::{Config, Model as M5, State, Tokenizer};
use candle_transformers::models::rwkv_v6::Model as M6;

use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};

const EOS_TOKEN_ID: u32 = 261;

enum Model {
    M5(M5),
    Q5(Q5),
    M6(M6),
    Q6(Q6),
}

impl Model {
    fn forward(&self, xs: &Tensor, state: &mut State) -> candle::Result<Tensor> {
        match self {
            Self::M5(m) => m.forward(xs, state),
            Self::Q5(m) => m.forward(xs, state),
            Self::M6(m) => m.forward(xs, state),
            Self::Q6(m) => m.forward(xs, state),
        }
    }
}

struct TextGeneration {
    model: Model,
    config: Config,
    device: Device,
    tokenizer: Tokenizer,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
}

impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    fn new(
        model: Model,
        config: Config,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        device: &Device,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
        Self {
            model,
            config,
            tokenizer,
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            device: device.clone(),
        }
    }

    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
        use std::io::Write;
        let mut tokens = self.tokenizer.encode(prompt)?;
        let mut generated_tokens = 0usize;
        let mut state = State::new(1, &self.config, &self.device)?;
        let mut next_logits = None;
        for &t in tokens.iter() {
            let input = Tensor::new(&[[t]], &self.device)?;
            let logits = self.model.forward(&input, &mut state)?;
            next_logits = Some(logits);
            print!("{}", self.tokenizer.decode(&[t])?)
        }
        std::io::stdout().flush()?;

        let start_gen = std::time::Instant::now();
        for _ in 0..sample_len {
            let logits = match next_logits.as_ref() {
                Some(logits) => logits,
                None => anyhow::bail!("cannot work on an empty prompt"),
            };
            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    self.repeat_penalty,
                    &tokens[start_at..],
                )?
            };
            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == EOS_TOKEN_ID || next_token == 0 {
                break;
            }
            print!("{}", self.tokenizer.decode(&[next_token])?);
            std::io::stdout().flush()?;

            let input = Tensor::new(&[[next_token]], &self.device)?;
            next_logits = Some(self.model.forward(&input, &mut state)?)
        }
        let dt = start_gen.elapsed();
        println!(
            "\n{generated_tokens} tokens generated ({:.2} token/s)",
            generated_tokens as f64 / dt.as_secs_f64(),
        );
        Ok(())
    }
}

#[derive(Parser, ValueEnum, Clone, Copy, PartialEq, Eq, Debug)]
enum Which {
    Eagle7b,
    World1b5,
    World3b,
    World6_1b6,
}

impl std::fmt::Display for Which {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:?}", self)
    }
}

impl Which {
    fn model_id(&self) -> &'static str {
        match self {
            Self::Eagle7b => "RWKV/v5-Eagle-7B-HF",
            Self::World1b5 => "RWKV/rwkv-5-world-1b5",
            Self::World3b => "RWKV/rwkv-5-world-3b",
            Self::World6_1b6 => "paperfun/rwkv",
        }
    }

    fn revision(&self) -> &'static str {
        match self {
            Self::Eagle7b => "refs/pr/1",
            Self::World1b5 | Self::World3b => "refs/pr/2",
            Self::World6_1b6 => "main",
        }
    }
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    #[arg(long)]
    prompt: String,

    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(long, short = 'n', default_value_t = 5000)]
    sample_len: usize,

    #[arg(long, default_value = "world1b5")]
    which: Which,

    #[arg(long)]
    model_id: Option<String>,

    #[arg(long)]
    revision: Option<String>,

    #[arg(long)]
    tokenizer: Option<String>,

    #[arg(long)]
    weight_files: Option<String>,

    #[arg(long)]
    config_file: Option<String>,

    #[arg(long)]
    quantized: bool,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature.unwrap_or(0.),
        args.repeat_penalty,
        args.repeat_last_n
    );

    let start = std::time::Instant::now();
    let api = Api::new()?;
    let repo = api.repo(Repo::with_revision(
        args.model_id
            .unwrap_or_else(|| args.which.model_id().to_string()),
        RepoType::Model,
        args.revision
            .unwrap_or_else(|| args.which.revision().to_string()),
    ));
    let tokenizer = match args.tokenizer {
        Some(file) => std::path::PathBuf::from(file),
        None => api
            .model("lmz/candle-rwkv".to_string())
            .get("rwkv_vocab_v20230424.json")?,
    };
    let config_filename = match args.config_file {
        Some(file) => std::path::PathBuf::from(file),
        None => repo.get("config.json")?,
    };
    let filenames = match args.weight_files {
        Some(files) => files
            .split(',')
            .map(std::path::PathBuf::from)
            .collect::<Vec<_>>(),
        None => {
            if args.quantized {
                vec![match args.which {
                    Which::World1b5 => api
                        .model("lmz/candle-rwkv".to_string())
                        .get("world1b5-q4k.gguf")?,
                    Which::World3b => api
                        .model("lmz/candle-rwkv".to_string())
                        .get("world3b-q4k.gguf")?,
                    Which::Eagle7b => api
                        .model("lmz/candle-rwkv".to_string())
                        .get("eagle7b-q4k.gguf")?,
                    Which::World6_1b6 => repo.get("rwkv-6-world-1b6-q4k.gguf")?,
                }]
            } else {
                vec![match args.which {
                    Which::World1b5 | Which::World3b | Which::Eagle7b => {
                        repo.get("model.safetensors")?
                    }
                    Which::World6_1b6 => repo.get("rwkv-6-world-1b6.safetensors")?,
                }]
            }
        }
    };
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::new(tokenizer)?;

    let start = std::time::Instant::now();
    let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
    let device = candle_examples::device(args.cpu)?;
    let model = if args.quantized {
        let filename = &filenames[0];
        let vb =
            candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
        match args.which {
            Which::World1b5 | Which::World3b | Which::Eagle7b => Model::Q5(Q5::new(&config, vb)?),
            Which::World6_1b6 => Model::Q6(Q6::new(&config, vb)?),
        }
    } else {
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
        match args.which {
            Which::World1b5 | Which::World3b | Which::Eagle7b => Model::M5(M5::new(&config, vb)?),
            Which::World6_1b6 => Model::M6(M6::new(&config, vb)?),
        }
    };
    println!("loaded the model in {:?}", start.elapsed());

    let mut pipeline = TextGeneration::new(
        model,
        config,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.repeat_penalty,
        args.repeat_last_n,
        &device,
    );
    pipeline.run(&args.prompt, args.sample_len)?;
    Ok(())
}
candle-examples/examples/segformer/README.md
0 → 100644
View file @
25d2752f
# candle-segformer
- [HuggingFace Segformer Model Card][segformer]
- [`mit-b0` - An encoder only pretrained model][encoder]
- [`segformer-b0-finetuned-ade-512-512` - A fine tuned model for segmentation][ade512]
## How to run the example
If you want, you can use the example images from this [pull request][pr]: download them and supply the path to the image as an argument to the example.
```bash
# run the image classification task
cargo run --example segformer classify <path-to-image>
# run the segmentation task
cargo run --example segformer segment <path-to-image>
```
Example output for classification:
```text
classification logits [3.275261e-5, 0.0008562019, 0.0008868563, 0.9977506, 0.0002465068, 0.0002241473, 2.846596e-6]
label: hamburger
```
[pr]: https://github.com/huggingface/candle/pull/1617
[segformer]: https://huggingface.co/docs/transformers/model_doc/segformer
[encoder]: https://huggingface.co/nvidia/mit-b0
[ade512]: https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512
candle-examples/examples/segformer/assets/labels.json
0 → 100644
View file @
25d2752f
[
  { "index": 1, "color": "#787878", "label": "wall" },
  { "index": 2, "color": "#B47878", "label": "building;edifice" },
  { "index": 3, "color": "#06E6E6", "label": "sky" },
  { "index": 4, "color": "#503232", "label": "floor;flooring" },
  { "index": 5, "color": "#04C803", "label": "tree" },
  { "index": 6, "color": "#787850", "label": "ceiling" },
  { "index": 7, "color": "#8C8C8C", "label": "road;route" },
  { "index": 8, "color": "#CC05FF", "label": "bed" },
  { "index": 9, "color": "#E6E6E6", "label": "windowpane;window" },
  { "index": 10, "color": "#04FA07", "label": "grass" },
  { "index": 11, "color": "#E005FF", "label": "cabinet" },
  { "index": 12, "color": "#EBFF07", "label": "sidewalk;pavement" },
  { "index": 13, "color": "#96053D", "label": "person;individual;someone;somebody;mortal;soul" },
  { "index": 14, "color": "#787846", "label": "earth;ground" },
  { "index": 15, "color": "#08FF33", "label": "door;double;door" },
  { "index": 16, "color": "#FF0652", "label": "table" },
  { "index": 17, "color": "#8FFF8C", "label": "mountain;mount" },
  { "index": 18, "color": "#CCFF04", "label": "plant;flora;plant;life" },
  { "index": 19, "color": "#FF3307", "label": "curtain;drape;drapery;mantle;pall" },
  { "index": 20, "color": "#CC4603", "label": "chair" },
  { "index": 21, "color": "#0066C8", "label": "car;auto;automobile;machine;motorcar" },
  { "index": 22, "color": "#3DE6FA", "label": "water" },
  { "index": 23, "color": "#FF0633", "label": "painting;picture" },
  { "index": 24, "color": "#0B66FF", "label": "sofa;couch;lounge" },
  { "index": 25, "color": "#FF0747", "label": "shelf" },
  { "index": 26, "color": "#FF09E0", "label": "house" },
  { "index": 27, "color": "#0907E6", "label": "sea" },
  { "index": 28, "color": "#DCDCDC", "label": "mirror" },
  { "index": 29, "color": "#FF095C", "label": "rug;carpet;carpeting" },
  { "index": 30, "color": "#7009FF", "label": "field" },
  { "index": 31, "color": "#08FFD6", "label": "armchair" },
  { "index": 32, "color": "#07FFE0", "label": "seat" },
  { "index": 33, "color": "#FFB806", "label": "fence;fencing" },
  { "index": 34, "color": "#0AFF47", "label": "desk" },
  { "index": 35, "color": "#FF290A", "label": "rock;stone" },
  { "index": 36, "color": "#07FFFF", "label": "wardrobe;closet;press" },
  { "index": 37, "color": "#E0FF08", "label": "lamp" },
  { "index": 38, "color": "#6608FF", "label": "bathtub;bathing;tub;bath;tub" },
  { "index": 39, "color": "#FF3D06", "label": "railing;rail" },
  { "index": 40, "color": "#FFC207", "label": "cushion" },
  { "index": 41, "color": "#FF7A08", "label": "base;pedestal;stand" },
  { "index": 42, "color": "#00FF14", "label": "box" },
  { "index": 43, "color": "#FF0829", "label": "column;pillar" },
  { "index": 44, "color": "#FF0599", "label": "signboard;sign" },
  { "index": 45, "color": "#0633FF", "label": "chest;of;drawers;chest;bureau;dresser" },
  { "index": 46, "color": "#EB0CFF", "label": "counter" },
  { "index": 47, "color": "#A09614", "label": "sand" },
  { "index": 48, "color": "#00A3FF", "label": "sink" },
  { "index": 49, "color": "#8C8C8C", "label": "skyscraper" },
  { "index": 50, "color": "#FA0A0F", "label": "fireplace;hearth;open;fireplace" },
  { "index": 51, "color": "#14FF00", "label": "refrigerator;icebox" },
  { "index": 52, "color": "#1FFF00", "label": "grandstand;covered;stand" },
  { "index": 53, "color": "#FF1F00", "label": "path" },
  { "index": 54, "color": "#FFE000", "label": "stairs;steps" },
  { "index": 55, "color": "#99FF00", "label": "runway" },
  { "index": 56, "color": "#0000FF", "label": "case;display;case;showcase;vitrine" },
  { "index": 57, "color": "#FF4700", "label": "pool;table;billiard;table;snooker;table" },
  { "index": 58, "color": "#00EBFF", "label": "pillow" },
  { "index": 59, "color": "#00ADFF", "label": "screen;door;screen" },
  { "index": 60, "color": "#1F00FF", "label": "stairway;staircase" },
  { "index": 61, "color": "#0BC8C8", "label": "river" },
  { "index": 62, "color": "#FF5200", "label": "bridge;span" },
  { "index": 63, "color": "#00FFF5", "label": "bookcase" },
  { "index": 64, "color": "#003DFF", "label": "blind;screen" },
  { "index": 65, "color": "#00FF70", "label": "coffee;table;cocktail;table" },
  { "index": 66, "color": "#00FF85", "label": "toilet;can;commode;crapper;pot;potty;stool;throne" },
  { "index": 67, "color": "#FF0000", "label": "flower" },
  { "index": 68, "color": "#FFA300", "label": "book" },
  { "index": 69, "color": "#FF6600", "label": "hill" },
  { "index": 70, "color": "#C2FF00", "label": "bench" },
  { "index": 71, "color": "#008FFF", "label": "countertop" },
  { "index": 72, "color": "#33FF00", "label": "stove;kitchen;stove;range;kitchen;range;cooking;stove" },
  { "index": 73, "color": "#0052FF", "label": "palm;palm;tree" },
  { "index": 74, "color": "#00FF29", "label": "kitchen;island" },
  { "index": 75, "color": "#00FFAD", "label": "computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system" },
  { "index": 76, "color": "#0A00FF", "label": "swivel;chair" },
  { "index": 77, "color": "#ADFF00", "label": "boat" },
  { "index": 78, "color": "#00FF99", "label": "bar" },
  { "index": 79, "color": "#FF5C00", "label": "arcade;machine" },
  { "index": 80, "color": "#FF00FF", "label": "hovel;hut;hutch;shack;shanty" },
  { "index": 81, "color": "#FF00F5", "label": "bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle" },
  { "index": 82, "color": "#FF0066", "label": "towel" },
  { "index": 83, "color": "#FFAD00", "label": "light;light;source" },
  { "index": 84, "color": "#FF0014", "label": "truck;motortruck" },
  { "index": 85, "color": "#FFB8B8", "label": "tower" },
  { "index": 86, "color": "#001FFF", "label": "chandelier;pendant;pendent" },
  { "index": 87, "color": "#00FF3D", "label": "awning;sunshade;sunblind" },
  { "index": 88, "color": "#0047FF", "label": "streetlight;street;lamp" },
  { "index": 89, "color": "#FF00CC", "label": "booth;cubicle;stall;kiosk" },
  { "index": 90, "color": "#00FFC2", "label": "television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box" },
  { "index": 91, "color": "#00FF52", "label": "airplane;aeroplane;plane" },
  { "index": 92, "color": "#000AFF", "label": "dirt;track" },
  { "index": 93, "color": "#0070FF", "label": "apparel;wearing;apparel;dress;clothes" },
  { "index": 94, "color": "#3300FF", "label": "pole" },
  { "index": 95, "color": "#00C2FF", "label": "land;ground;soil" },
  { "index": 96, "color": "#007AFF", "label": "bannister;banister;balustrade;balusters;handrail" },
  { "index": 97, "color": "#00FFA3", "label": "escalator;moving;staircase;moving;stairway" },
  { "index": 98, "color": "#FF9900", "label": "ottoman;pouf;pouffe;puff;hassock" },
  { "index": 99, "color": "#00FF0A", "label": "bottle" },
  { "index": 100, "color": "#FF7000", "label": "buffet;counter;sideboard" },
  { "index": 101, "color": "#8FFF00", "label": "poster;posting;placard;notice;bill;card" },
  { "index": 102, "color": "#5200FF", "label": "stage" },
  { "index": 103, "color": "#A3FF00", "label": "van" },
  { "index": 104, "color": "#FFEB00", "label": "ship" },
  { "index": 105, "color": "#08B8AA", "label": "fountain" },
  { "index": 106, "color": "#8500FF", "label": "conveyer;belt;conveyor;belt;conveyer;conveyor;transporter" },
  { "index": 107, "color": "#00FF5C", "label": "canopy" },
  { "index": 108, "color": "#B800FF", "label": "washer;automatic;washer;washing;machine" },
  { "index": 109, "color": "#FF001F", "label": "plaything;toy" },
  { "index": 110, "color": "#00B8FF", "label": "swimming;pool;swimming;bath;natatorium" },
  { "index": 111, "color": "#00D6FF", "label": "stool" },
  { "index": 112, "color": "#FF0070", "label": "barrel;cask" },
  { "index": 113, "color": "#5CFF00", "label": "basket;handbasket" },
  { "index": 114, "color": "#00E0FF", "label": "waterfall;falls" },
  { "index": 115, "color": "#70E0FF", "label": "tent;collapsible;shelter" },
  { "index": 116, "color": "#46B8A0", "label": "bag" },
  { "index": 117, "color": "#A300FF", "label": "minibike;motorbike" },
  { "index": 118, "color": "#9900FF", "label": "cradle" },
  { "index": 119, "color": "#47FF00", "label": "oven" },
  { "index": 120, "color": "#FF00A3", "label": "ball" },
  { "index": 121, "color": "#FFCC00", "label": "food;solid;food" },
  { "index": 122, "color": "#FF008F", "label": "step;stair" },
  { "index": 123, "color": "#00FFEB", "label": "tank;storage;tank" },
  { "index": 124, "color": "#85FF00", "label": "trade;name;brand;name;brand;marque" },
  { "index": 125, "color": "#FF00EB", "label": "microwave;microwave;oven" },
  { "index": 126, "color": "#F500FF", "label": "pot;flowerpot" },
  { "index": 127, "color": "#FF007A", "label": "animal;animate;being;beast;brute;creature;fauna" },
  { "index": 128, "color": "#FFF500", "label": "bicycle;bike;wheel;cycle" },
  { "index": 129, "color": "#0ABED4", "label": "lake" },
  { "index": 130, "color": "#D6FF00", "label": "dishwasher;dish;washer;dishwashing;machine" },
  { "index": 131, "color": "#00CCFF", "label": "screen;silver;screen;projection;screen" },
  { "index": 132, "color": "#1400FF", "label": "blanket;cover" },
  { "index": 133, "color": "#FFFF00", "label": "sculpture" },
  { "index": 134, "color": "#0099FF", "label": "hood;exhaust;hood" },
  { "index": 135, "color": "#0029FF", "label": "sconce" },
  { "index": 136, "color": "#00FFCC", "label": "vase" },
  { "index": 137, "color": "#2900FF", "label": "traffic;light;traffic;signal;stoplight" },
  { "index": 138, "color": "#29FF00", "label": "tray" },
  { "index": 139, "color": "#AD00FF", "label": "ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin" },
  { "index": 140, "color": "#00F5FF", "label": "fan" },
  { "index": 141, "color": "#4700FF", "label": "pier;wharf;wharfage;dock" },
  { "index": 142, "color": "#7A00FF", "label": "crt;screen" },
  { "index": 143, "color": "#00FFB8", "label": "plate" },
  { "index": 144, "color": "#005CFF", "label": "monitor;monitoring;device" },
  { "index": 145, "color": "#B8FF00", "label": "bulletin;board;notice;board" },
  { "index": 146, "color": "#0085FF", "label": "shower" },
  { "index": 147, "color": "#FFD600", "label": "radiator" },
  { "index": 148, "color": "#19C2C2", "label": "glass;drinking;glass" },
  { "index": 149, "color": "#66FF00", "label": "clock" },
  { "index": 150, "color": "#5C00FF", "label": "flag" }
]