Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
yongshk
candle
Commits
25d2752f
Commit
25d2752f
authored
May 29, 2025
by
yongshk
Browse files
Initial commit
parents
Changes
238
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
3058 additions
and
0 deletions
+3058
-0
candle-examples/examples/reinforcement-learning/atari_wrappers.py
...xamples/examples/reinforcement-learning/atari_wrappers.py
+308
-0
candle-examples/examples/reinforcement-learning/ddpg.rs
candle-examples/examples/reinforcement-learning/ddpg.rs
+556
-0
candle-examples/examples/reinforcement-learning/dqn.rs
candle-examples/examples/reinforcement-learning/dqn.rs
+118
-0
candle-examples/examples/reinforcement-learning/gym_env.rs
candle-examples/examples/reinforcement-learning/gym_env.rs
+114
-0
candle-examples/examples/reinforcement-learning/main.rs
candle-examples/examples/reinforcement-learning/main.rs
+40
-0
candle-examples/examples/reinforcement-learning/policy_gradient.rs
...amples/examples/reinforcement-learning/policy_gradient.rs
+146
-0
candle-examples/examples/reinforcement-learning/vec_gym_env.rs
...e-examples/examples/reinforcement-learning/vec_gym_env.rs
+91
-0
candle-examples/examples/replit-code/README.md
candle-examples/examples/replit-code/README.md
+40
-0
candle-examples/examples/replit-code/main.rs
candle-examples/examples/replit-code/main.rs
+264
-0
candle-examples/examples/repvgg/README.md
candle-examples/examples/repvgg/README.md
+22
-0
candle-examples/examples/repvgg/main.rs
candle-examples/examples/repvgg/main.rs
+111
-0
candle-examples/examples/resnet/README.md
candle-examples/examples/resnet/README.md
+19
-0
candle-examples/examples/resnet/export_models.py
candle-examples/examples/resnet/export_models.py
+12
-0
candle-examples/examples/resnet/main.rs
candle-examples/examples/resnet/main.rs
+90
-0
candle-examples/examples/rwkv/README.md
candle-examples/examples/rwkv/README.md
+17
-0
candle-examples/examples/rwkv/main.rs
candle-examples/examples/rwkv/main.rs
+330
-0
candle-examples/examples/segformer/README.md
candle-examples/examples/segformer/README.md
+28
-0
candle-examples/examples/segformer/assets/labels.json
candle-examples/examples/segformer/assets/labels.json
+752
-0
No files found.
Too many changes to show.
To preserve performance only
238 of 238+
files are displayed.
Plain diff
Email patch
candle-examples/examples/reinforcement-learning/atari_wrappers.py
0 → 100644
View file @
25d2752f
import
gymnasium
as
gym
import
numpy
as
np
from
collections
import
deque
from
PIL
import
Image
from
multiprocessing
import
Process
,
Pipe
# atari_wrappers.py
class
NoopResetEnv
(
gym
.
Wrapper
):
def
__init__
(
self
,
env
,
noop_max
=
30
):
"""Sample initial states by taking random number of no-ops on reset.
No-op is assumed to be action 0.
"""
gym
.
Wrapper
.
__init__
(
self
,
env
)
self
.
noop_max
=
noop_max
self
.
override_num_noops
=
None
assert
env
.
unwrapped
.
get_action_meanings
()[
0
]
==
'NOOP'
def
reset
(
self
):
""" Do no-op action for a number of steps in [1, noop_max]."""
self
.
env
.
reset
()
if
self
.
override_num_noops
is
not
None
:
noops
=
self
.
override_num_noops
else
:
noops
=
self
.
unwrapped
.
np_random
.
integers
(
1
,
self
.
noop_max
+
1
)
#pylint: disable=E1101
assert
noops
>
0
obs
=
None
for
_
in
range
(
noops
):
obs
,
_
,
done
,
_
=
self
.
env
.
step
(
0
)
if
done
:
obs
=
self
.
env
.
reset
()
return
obs
class
FireResetEnv
(
gym
.
Wrapper
):
def
__init__
(
self
,
env
):
"""Take action on reset for environments that are fixed until firing."""
gym
.
Wrapper
.
__init__
(
self
,
env
)
assert
env
.
unwrapped
.
get_action_meanings
()[
1
]
==
'FIRE'
assert
len
(
env
.
unwrapped
.
get_action_meanings
())
>=
3
def
reset
(
self
):
self
.
env
.
reset
()
obs
,
_
,
done
,
_
=
self
.
env
.
step
(
1
)
if
done
:
self
.
env
.
reset
()
obs
,
_
,
done
,
_
=
self
.
env
.
step
(
2
)
if
done
:
self
.
env
.
reset
()
return
obs
class
ImageSaver
(
gym
.
Wrapper
):
def
__init__
(
self
,
env
,
img_path
,
rank
):
gym
.
Wrapper
.
__init__
(
self
,
env
)
self
.
_cnt
=
0
self
.
_img_path
=
img_path
self
.
_rank
=
rank
def
step
(
self
,
action
):
step_result
=
self
.
env
.
step
(
action
)
obs
,
_
,
_
,
_
=
step_result
img
=
Image
.
fromarray
(
obs
,
'RGB'
)
img
.
save
(
'%s/out%d-%05d.png'
%
(
self
.
_img_path
,
self
.
_rank
,
self
.
_cnt
))
self
.
_cnt
+=
1
return
step_result
class
EpisodicLifeEnv
(
gym
.
Wrapper
):
def
__init__
(
self
,
env
):
"""Make end-of-life == end-of-episode, but only reset on true game over.
Done by DeepMind for the DQN and co. since it helps value estimation.
"""
gym
.
Wrapper
.
__init__
(
self
,
env
)
self
.
lives
=
0
self
.
was_real_done
=
True
def
step
(
self
,
action
):
obs
,
reward
,
done
,
info
=
self
.
env
.
step
(
action
)
self
.
was_real_done
=
done
# check current lives, make loss of life terminal,
# then update lives to handle bonus lives
lives
=
self
.
env
.
unwrapped
.
ale
.
lives
()
if
lives
<
self
.
lives
and
lives
>
0
:
# for Qbert sometimes we stay in lives == 0 condition for a few frames
# so its important to keep lives > 0, so that we only reset once
# the environment advertises done.
done
=
True
self
.
lives
=
lives
return
obs
,
reward
,
done
,
info
def
reset
(
self
):
"""Reset only when lives are exhausted.
This way all states are still reachable even though lives are episodic,
and the learner need not know about any of this behind-the-scenes.
"""
if
self
.
was_real_done
:
obs
=
self
.
env
.
reset
()
else
:
# no-op step to advance from terminal/lost life state
obs
,
_
,
_
,
_
=
self
.
env
.
step
(
0
)
self
.
lives
=
self
.
env
.
unwrapped
.
ale
.
lives
()
return
obs
class
MaxAndSkipEnv
(
gym
.
Wrapper
):
def
__init__
(
self
,
env
,
skip
=
4
):
"""Return only every `skip`-th frame"""
gym
.
Wrapper
.
__init__
(
self
,
env
)
# most recent raw observations (for max pooling across time steps)
self
.
_obs_buffer
=
deque
(
maxlen
=
2
)
self
.
_skip
=
skip
def
step
(
self
,
action
):
"""Repeat action, sum reward, and max over last observations."""
total_reward
=
0.0
done
=
None
for
_
in
range
(
self
.
_skip
):
obs
,
reward
,
done
,
info
=
self
.
env
.
step
(
action
)
self
.
_obs_buffer
.
append
(
obs
)
total_reward
+=
reward
if
done
:
break
max_frame
=
np
.
max
(
np
.
stack
(
self
.
_obs_buffer
),
axis
=
0
)
return
max_frame
,
total_reward
,
done
,
info
def
reset
(
self
):
"""Clear past frame buffer and init. to first obs. from inner env."""
self
.
_obs_buffer
.
clear
()
obs
=
self
.
env
.
reset
()
self
.
_obs_buffer
.
append
(
obs
)
return
obs
class
ClipRewardEnv
(
gym
.
RewardWrapper
):
def
reward
(
self
,
reward
):
"""Bin reward to {+1, 0, -1} by its sign."""
return
np
.
sign
(
reward
)
class
WarpFrame
(
gym
.
ObservationWrapper
):
def
__init__
(
self
,
env
):
"""Warp frames to 84x84 as done in the Nature paper and later work."""
gym
.
ObservationWrapper
.
__init__
(
self
,
env
)
self
.
res
=
84
self
.
observation_space
=
gym
.
spaces
.
Box
(
low
=
0
,
high
=
255
,
shape
=
(
self
.
res
,
self
.
res
,
1
),
dtype
=
'uint8'
)
def
observation
(
self
,
obs
):
frame
=
np
.
dot
(
obs
.
astype
(
'float32'
),
np
.
array
([
0.299
,
0.587
,
0.114
],
'float32'
))
frame
=
np
.
array
(
Image
.
fromarray
(
frame
).
resize
((
self
.
res
,
self
.
res
),
resample
=
Image
.
BILINEAR
),
dtype
=
np
.
uint8
)
return
frame
.
reshape
((
self
.
res
,
self
.
res
,
1
))
class
FrameStack
(
gym
.
Wrapper
):
def
__init__
(
self
,
env
,
k
):
"""Buffer observations and stack across channels (last axis)."""
gym
.
Wrapper
.
__init__
(
self
,
env
)
self
.
k
=
k
self
.
frames
=
deque
([],
maxlen
=
k
)
shp
=
env
.
observation_space
.
shape
assert
shp
[
2
]
==
1
# can only stack 1-channel frames
self
.
observation_space
=
gym
.
spaces
.
Box
(
low
=
0
,
high
=
255
,
shape
=
(
shp
[
0
],
shp
[
1
],
k
),
dtype
=
'uint8'
)
def
reset
(
self
):
"""Clear buffer and re-fill by duplicating the first observation."""
ob
=
self
.
env
.
reset
()
for
_
in
range
(
self
.
k
):
self
.
frames
.
append
(
ob
)
return
self
.
observation
()
def
step
(
self
,
action
):
ob
,
reward
,
done
,
info
=
self
.
env
.
step
(
action
)
self
.
frames
.
append
(
ob
)
return
self
.
observation
(),
reward
,
done
,
info
def
observation
(
self
):
assert
len
(
self
.
frames
)
==
self
.
k
return
np
.
concatenate
(
self
.
frames
,
axis
=
2
)
def
wrap_deepmind
(
env
,
episode_life
=
True
,
clip_rewards
=
True
):
"""Configure environment for DeepMind-style Atari.
Note: this does not include frame stacking!"""
assert
'NoFrameskip'
in
env
.
spec
.
id
# required for DeepMind-style skip
if
episode_life
:
env
=
EpisodicLifeEnv
(
env
)
env
=
NoopResetEnv
(
env
,
noop_max
=
30
)
env
=
MaxAndSkipEnv
(
env
,
skip
=
4
)
if
'FIRE'
in
env
.
unwrapped
.
get_action_meanings
():
env
=
FireResetEnv
(
env
)
env
=
WarpFrame
(
env
)
if
clip_rewards
:
env
=
ClipRewardEnv
(
env
)
return
env
# envs.py
def
make_env
(
env_id
,
img_dir
,
seed
,
rank
):
def
_thunk
():
env
=
gym
.
make
(
env_id
)
env
.
reset
(
seed
=
(
seed
+
rank
))
if
img_dir
is
not
None
:
env
=
ImageSaver
(
env
,
img_dir
,
rank
)
env
=
wrap_deepmind
(
env
)
env
=
WrapPyTorch
(
env
)
return
env
return
_thunk
class
WrapPyTorch
(
gym
.
ObservationWrapper
):
def
__init__
(
self
,
env
=
None
):
super
(
WrapPyTorch
,
self
).
__init__
(
env
)
self
.
observation_space
=
gym
.
spaces
.
Box
(
0.0
,
1.0
,
[
1
,
84
,
84
],
dtype
=
'float32'
)
def
observation
(
self
,
observation
):
return
observation
.
transpose
(
2
,
0
,
1
)
# vecenv.py
class
VecEnv
(
object
):
"""
Vectorized environment base class
"""
def
step
(
self
,
vac
):
"""
Apply sequence of actions to sequence of environments
actions -> (observations, rewards, news)
where 'news' is a boolean vector indicating whether each element is new.
"""
raise
NotImplementedError
def
reset
(
self
):
"""
Reset all environments
"""
raise
NotImplementedError
def
close
(
self
):
pass
# subproc_vec_env.py
def
worker
(
remote
,
env_fn_wrapper
):
env
=
env_fn_wrapper
.
x
()
while
True
:
cmd
,
data
=
remote
.
recv
()
if
cmd
==
'step'
:
ob
,
reward
,
done
,
info
=
env
.
step
(
data
)
if
done
:
ob
=
env
.
reset
()
remote
.
send
((
ob
,
reward
,
done
,
info
))
elif
cmd
==
'reset'
:
ob
=
env
.
reset
()
remote
.
send
(
ob
)
elif
cmd
==
'close'
:
remote
.
close
()
break
elif
cmd
==
'get_spaces'
:
remote
.
send
((
env
.
action_space
,
env
.
observation_space
))
else
:
raise
NotImplementedError
class
CloudpickleWrapper
(
object
):
"""
Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
"""
def
__init__
(
self
,
x
):
self
.
x
=
x
def
__getstate__
(
self
):
import
cloudpickle
return
cloudpickle
.
dumps
(
self
.
x
)
def
__setstate__
(
self
,
ob
):
import
pickle
self
.
x
=
pickle
.
loads
(
ob
)
class
SubprocVecEnv
(
VecEnv
):
def
__init__
(
self
,
env_fns
):
"""
envs: list of gym environments to run in subprocesses
"""
nenvs
=
len
(
env_fns
)
self
.
remotes
,
self
.
work_remotes
=
zip
(
*
[
Pipe
()
for
_
in
range
(
nenvs
)])
self
.
ps
=
[
Process
(
target
=
worker
,
args
=
(
work_remote
,
CloudpickleWrapper
(
env_fn
)))
for
(
work_remote
,
env_fn
)
in
zip
(
self
.
work_remotes
,
env_fns
)]
for
p
in
self
.
ps
:
p
.
start
()
self
.
remotes
[
0
].
send
((
'get_spaces'
,
None
))
self
.
action_space
,
self
.
observation_space
=
self
.
remotes
[
0
].
recv
()
def
step
(
self
,
actions
):
for
remote
,
action
in
zip
(
self
.
remotes
,
actions
):
remote
.
send
((
'step'
,
action
))
results
=
[
remote
.
recv
()
for
remote
in
self
.
remotes
]
obs
,
rews
,
dones
,
infos
=
zip
(
*
results
)
return
np
.
stack
(
obs
),
np
.
stack
(
rews
),
np
.
stack
(
dones
),
infos
def
reset
(
self
):
for
remote
in
self
.
remotes
:
remote
.
send
((
'reset'
,
None
))
return
np
.
stack
([
remote
.
recv
()
for
remote
in
self
.
remotes
])
def
close
(
self
):
for
remote
in
self
.
remotes
:
remote
.
send
((
'close'
,
None
))
for
p
in
self
.
ps
:
p
.
join
()
@
property
def
num_envs
(
self
):
return
len
(
self
.
remotes
)
# Create the environment.
def
make
(
env_name
,
img_dir
,
num_processes
):
envs
=
SubprocVecEnv
([
make_env
(
env_name
,
img_dir
,
1337
,
i
)
for
i
in
range
(
num_processes
)
])
return
envs
candle-examples/examples/reinforcement-learning/ddpg.rs
0 → 100644
View file @
25d2752f
use
std
::
collections
::
VecDeque
;
use
std
::
fmt
::
Display
;
use
candle
::{
DType
,
Device
,
Error
,
Module
,
Result
,
Tensor
,
Var
};
use
candle_nn
::{
func
,
linear
,
sequential
::
seq
,
Activation
,
AdamW
,
Optimizer
,
ParamsAdamW
,
Sequential
,
VarBuilder
,
VarMap
,
};
use
rand
::{
distributions
::
Uniform
,
thread_rng
,
Rng
};
use
super
::
gym_env
::
GymEnv
;
pub
struct
OuNoise
{
mu
:
f64
,
theta
:
f64
,
sigma
:
f64
,
state
:
Tensor
,
}
impl
OuNoise
{
pub
fn
new
(
mu
:
f64
,
theta
:
f64
,
sigma
:
f64
,
size_action
:
usize
)
->
Result
<
Self
>
{
Ok
(
Self
{
mu
,
theta
,
sigma
,
state
:
Tensor
::
ones
(
size_action
,
DType
::
F32
,
&
Device
::
Cpu
)
?
,
})
}
pub
fn
sample
(
&
mut
self
)
->
Result
<
Tensor
>
{
let
rand
=
Tensor
::
randn_like
(
&
self
.state
,
0.0
,
1.0
)
?
;
let
dx
=
((
self
.theta
*
(
self
.mu
-
&
self
.state
)
?
)
?
+
(
self
.sigma
*
rand
)
?
)
?
;
self
.state
=
(
&
self
.state
+
dx
)
?
;
Ok
(
self
.state
.clone
())
}
}
#[derive(Clone)]
struct
Transition
{
state
:
Tensor
,
action
:
Tensor
,
reward
:
Tensor
,
next_state
:
Tensor
,
terminated
:
bool
,
truncated
:
bool
,
}
impl
Transition
{
fn
new
(
state
:
&
Tensor
,
action
:
&
Tensor
,
reward
:
&
Tensor
,
next_state
:
&
Tensor
,
terminated
:
bool
,
truncated
:
bool
,
)
->
Self
{
Self
{
state
:
state
.clone
(),
action
:
action
.clone
(),
reward
:
reward
.clone
(),
next_state
:
next_state
.clone
(),
terminated
,
truncated
,
}
}
}
pub
struct
ReplayBuffer
{
buffer
:
VecDeque
<
Transition
>
,
capacity
:
usize
,
size
:
usize
,
}
impl
ReplayBuffer
{
pub
fn
new
(
capacity
:
usize
)
->
Self
{
Self
{
buffer
:
VecDeque
::
with_capacity
(
capacity
),
capacity
,
size
:
0
,
}
}
pub
fn
push
(
&
mut
self
,
state
:
&
Tensor
,
action
:
&
Tensor
,
reward
:
&
Tensor
,
next_state
:
&
Tensor
,
terminated
:
bool
,
truncated
:
bool
,
)
{
if
self
.size
==
self
.capacity
{
self
.buffer
.pop_front
();
}
else
{
self
.size
+=
1
;
}
self
.buffer
.push_back
(
Transition
::
new
(
state
,
action
,
reward
,
next_state
,
terminated
,
truncated
,
));
}
#[allow(clippy::type_complexity)]
pub
fn
random_batch
(
&
self
,
batch_size
:
usize
,
)
->
Result
<
Option
<
(
Tensor
,
Tensor
,
Tensor
,
Tensor
,
Vec
<
bool
>
,
Vec
<
bool
>
)
>>
{
if
self
.size
<
batch_size
{
Ok
(
None
)
}
else
{
let
transitions
:
Vec
<&
Transition
>
=
thread_rng
()
.sample_iter
(
Uniform
::
from
(
0
..
self
.size
))
.take
(
batch_size
)
.map
(|
i
|
self
.buffer
.get
(
i
)
.unwrap
())
.collect
();
let
states
:
Vec
<
Tensor
>
=
transitions
.iter
()
.map
(|
t
|
t
.state
.unsqueeze
(
0
))
.collect
::
<
Result
<
_
>>
()
?
;
let
actions
:
Vec
<
Tensor
>
=
transitions
.iter
()
.map
(|
t
|
t
.action
.unsqueeze
(
0
))
.collect
::
<
Result
<
_
>>
()
?
;
let
rewards
:
Vec
<
Tensor
>
=
transitions
.iter
()
.map
(|
t
|
t
.reward
.unsqueeze
(
0
))
.collect
::
<
Result
<
_
>>
()
?
;
let
next_states
:
Vec
<
Tensor
>
=
transitions
.iter
()
.map
(|
t
|
t
.next_state
.unsqueeze
(
0
))
.collect
::
<
Result
<
_
>>
()
?
;
let
terminateds
:
Vec
<
bool
>
=
transitions
.iter
()
.map
(|
t
|
t
.terminated
)
.collect
();
let
truncateds
:
Vec
<
bool
>
=
transitions
.iter
()
.map
(|
t
|
t
.truncated
)
.collect
();
Ok
(
Some
((
Tensor
::
cat
(
&
states
,
0
)
?
,
Tensor
::
cat
(
&
actions
,
0
)
?
,
Tensor
::
cat
(
&
rewards
,
0
)
?
,
Tensor
::
cat
(
&
next_states
,
0
)
?
,
terminateds
,
truncateds
,
)))
}
}
}
fn
track
(
varmap
:
&
mut
VarMap
,
vb
:
&
VarBuilder
,
target_prefix
:
&
str
,
network_prefix
:
&
str
,
dims
:
&
[(
usize
,
usize
)],
tau
:
f64
,
)
->
Result
<
()
>
{
for
(
i
,
&
(
in_dim
,
out_dim
))
in
dims
.iter
()
.enumerate
()
{
let
target_w
=
vb
.get
((
out_dim
,
in_dim
),
&
format!
(
"{target_prefix}-fc{i}.weight"
))
?
;
let
network_w
=
vb
.get
((
out_dim
,
in_dim
),
&
format!
(
"{network_prefix}-fc{i}.weight"
))
?
;
varmap
.set_one
(
format!
(
"{target_prefix}-fc{i}.weight"
),
((
tau
*
network_w
)
?
+
((
1.0
-
tau
)
*
target_w
)
?
)
?
,
)
?
;
let
target_b
=
vb
.get
(
out_dim
,
&
format!
(
"{target_prefix}-fc{i}.bias"
))
?
;
let
network_b
=
vb
.get
(
out_dim
,
&
format!
(
"{network_prefix}-fc{i}.bias"
))
?
;
varmap
.set_one
(
format!
(
"{target_prefix}-fc{i}.bias"
),
((
tau
*
network_b
)
?
+
((
1.0
-
tau
)
*
target_b
)
?
)
?
,
)
?
;
}
Ok
(())
}
struct
Actor
<
'a
>
{
varmap
:
VarMap
,
vb
:
VarBuilder
<
'a
>
,
network
:
Sequential
,
target_network
:
Sequential
,
size_state
:
usize
,
size_action
:
usize
,
dims
:
Vec
<
(
usize
,
usize
)
>
,
}
impl
Actor
<
'_
>
{
fn
new
(
device
:
&
Device
,
dtype
:
DType
,
size_state
:
usize
,
size_action
:
usize
)
->
Result
<
Self
>
{
let
mut
varmap
=
VarMap
::
new
();
let
vb
=
VarBuilder
::
from_varmap
(
&
varmap
,
dtype
,
device
);
let
dims
=
vec!
[(
size_state
,
400
),
(
400
,
300
),
(
300
,
size_action
)];
let
make_network
=
|
prefix
:
&
str
|
{
let
seq
=
seq
()
.add
(
linear
(
dims
[
0
]
.0
,
dims
[
0
]
.1
,
vb
.pp
(
format!
(
"{prefix}-fc0"
)),
)
?
)
.add
(
Activation
::
Relu
)
.add
(
linear
(
dims
[
1
]
.0
,
dims
[
1
]
.1
,
vb
.pp
(
format!
(
"{prefix}-fc1"
)),
)
?
)
.add
(
Activation
::
Relu
)
.add
(
linear
(
dims
[
2
]
.0
,
dims
[
2
]
.1
,
vb
.pp
(
format!
(
"{prefix}-fc2"
)),
)
?
)
.add
(
func
(|
xs
|
xs
.tanh
()));
Ok
::
<
Sequential
,
Error
>
(
seq
)
};
let
network
=
make_network
(
"actor"
)
?
;
let
target_network
=
make_network
(
"target-actor"
)
?
;
// this sets the two networks to be equal to each other using tau = 1.0
track
(
&
mut
varmap
,
&
vb
,
"target-actor"
,
"actor"
,
&
dims
,
1.0
);
Ok
(
Self
{
varmap
,
vb
,
network
,
target_network
,
size_state
,
size_action
,
dims
,
})
}
fn
forward
(
&
self
,
state
:
&
Tensor
)
->
Result
<
Tensor
>
{
self
.network
.forward
(
state
)
}
fn
target_forward
(
&
self
,
state
:
&
Tensor
)
->
Result
<
Tensor
>
{
self
.target_network
.forward
(
state
)
}
fn
track
(
&
mut
self
,
tau
:
f64
)
->
Result
<
()
>
{
track
(
&
mut
self
.varmap
,
&
self
.vb
,
"target-actor"
,
"actor"
,
&
self
.dims
,
tau
,
)
}
}
struct
Critic
<
'a
>
{
varmap
:
VarMap
,
vb
:
VarBuilder
<
'a
>
,
network
:
Sequential
,
target_network
:
Sequential
,
size_state
:
usize
,
size_action
:
usize
,
dims
:
Vec
<
(
usize
,
usize
)
>
,
}
impl
Critic
<
'_
>
{
fn
new
(
device
:
&
Device
,
dtype
:
DType
,
size_state
:
usize
,
size_action
:
usize
)
->
Result
<
Self
>
{
let
mut
varmap
=
VarMap
::
new
();
let
vb
=
VarBuilder
::
from_varmap
(
&
varmap
,
dtype
,
device
);
let
dims
:
Vec
<
(
usize
,
usize
)
>
=
vec!
[(
size_state
+
size_action
,
400
),
(
400
,
300
),
(
300
,
1
)];
let
make_network
=
|
prefix
:
&
str
|
{
let
seq
=
seq
()
.add
(
linear
(
dims
[
0
]
.0
,
dims
[
0
]
.1
,
vb
.pp
(
format!
(
"{prefix}-fc0"
)),
)
?
)
.add
(
Activation
::
Relu
)
.add
(
linear
(
dims
[
1
]
.0
,
dims
[
1
]
.1
,
vb
.pp
(
format!
(
"{prefix}-fc1"
)),
)
?
)
.add
(
Activation
::
Relu
)
.add
(
linear
(
dims
[
2
]
.0
,
dims
[
2
]
.1
,
vb
.pp
(
format!
(
"{prefix}-fc2"
)),
)
?
);
Ok
::
<
Sequential
,
Error
>
(
seq
)
};
let
network
=
make_network
(
"critic"
)
?
;
let
target_network
=
make_network
(
"target-critic"
)
?
;
// this sets the two networks to be equal to each other using tau = 1.0
track
(
&
mut
varmap
,
&
vb
,
"target-critic"
,
"critic"
,
&
dims
,
1.0
);
Ok
(
Self
{
varmap
,
vb
,
network
,
target_network
,
size_state
,
size_action
,
dims
,
})
}
fn
forward
(
&
self
,
state
:
&
Tensor
,
action
:
&
Tensor
)
->
Result
<
Tensor
>
{
let
xs
=
Tensor
::
cat
(
&
[
action
,
state
],
1
)
?
;
self
.network
.forward
(
&
xs
)
}
fn
target_forward
(
&
self
,
state
:
&
Tensor
,
action
:
&
Tensor
)
->
Result
<
Tensor
>
{
let
xs
=
Tensor
::
cat
(
&
[
action
,
state
],
1
)
?
;
self
.target_network
.forward
(
&
xs
)
}
fn
track
(
&
mut
self
,
tau
:
f64
)
->
Result
<
()
>
{
track
(
&
mut
self
.varmap
,
&
self
.vb
,
"target-critic"
,
"critic"
,
&
self
.dims
,
tau
,
)
}
}
#[allow(clippy::upper_case_acronyms)]
pub
struct
DDPG
<
'a
>
{
actor
:
Actor
<
'a
>
,
actor_optim
:
AdamW
,
critic
:
Critic
<
'a
>
,
critic_optim
:
AdamW
,
gamma
:
f64
,
tau
:
f64
,
replay_buffer
:
ReplayBuffer
,
ou_noise
:
OuNoise
,
size_state
:
usize
,
size_action
:
usize
,
pub
train
:
bool
,
}
impl
DDPG
<
'_
>
{
#[allow(clippy::too_many_arguments)]
pub
fn
new
(
device
:
&
Device
,
size_state
:
usize
,
size_action
:
usize
,
train
:
bool
,
actor_lr
:
f64
,
critic_lr
:
f64
,
gamma
:
f64
,
tau
:
f64
,
buffer_capacity
:
usize
,
ou_noise
:
OuNoise
,
)
->
Result
<
Self
>
{
let
filter_by_prefix
=
|
varmap
:
&
VarMap
,
prefix
:
&
str
|
{
varmap
.data
()
.lock
()
.unwrap
()
.iter
()
.filter_map
(|(
name
,
var
)|
name
.starts_with
(
prefix
)
.then_some
(
var
.clone
()))
.collect
::
<
Vec
<
Var
>>
()
};
let
actor
=
Actor
::
new
(
device
,
DType
::
F32
,
size_state
,
size_action
)
?
;
let
actor_optim
=
AdamW
::
new
(
filter_by_prefix
(
&
actor
.varmap
,
"actor"
),
ParamsAdamW
{
lr
:
actor_lr
,
..
Default
::
default
()
},
)
?
;
let
critic
=
Critic
::
new
(
device
,
DType
::
F32
,
size_state
,
size_action
)
?
;
let
critic_optim
=
AdamW
::
new
(
filter_by_prefix
(
&
critic
.varmap
,
"critic"
),
ParamsAdamW
{
lr
:
critic_lr
,
..
Default
::
default
()
},
)
?
;
Ok
(
Self
{
actor
,
actor_optim
,
critic
,
critic_optim
,
gamma
,
tau
,
replay_buffer
:
ReplayBuffer
::
new
(
buffer_capacity
),
ou_noise
,
size_state
,
size_action
,
train
,
})
}
pub
fn
remember
(
&
mut
self
,
state
:
&
Tensor
,
action
:
&
Tensor
,
reward
:
&
Tensor
,
next_state
:
&
Tensor
,
terminated
:
bool
,
truncated
:
bool
,
)
{
self
.replay_buffer
.push
(
state
,
action
,
reward
,
next_state
,
terminated
,
truncated
)
}
pub
fn
actions
(
&
mut
self
,
state
:
&
Tensor
)
->
Result
<
f32
>
{
let
actions
=
self
.actor
.forward
(
&
state
.detach
()
.unsqueeze
(
0
)
?
)
?
.squeeze
(
0
)
?
;
let
actions
=
if
self
.train
{
(
actions
+
self
.ou_noise
.sample
()
?
)
?
}
else
{
actions
};
actions
.squeeze
(
0
)
?
.to_scalar
::
<
f32
>
()
}
pub
fn
train
(
&
mut
self
,
batch_size
:
usize
)
->
Result
<
()
>
{
let
(
states
,
actions
,
rewards
,
next_states
,
_
,
_
)
=
match
self
.replay_buffer
.random_batch
(
batch_size
)
?
{
Some
(
v
)
=>
v
,
_
=>
return
Ok
(()),
};
let
q_target
=
self
.critic
.target_forward
(
&
next_states
,
&
self
.actor
.target_forward
(
&
next_states
)
?
)
?
;
let
q_target
=
(
rewards
+
(
self
.gamma
*
q_target
)
?
.detach
())
?
;
let
q
=
self
.critic
.forward
(
&
states
,
&
actions
)
?
;
let
diff
=
(
q_target
-
q
)
?
;
let
critic_loss
=
diff
.sqr
()
?
.mean_all
()
?
;
self
.critic_optim
.backward_step
(
&
critic_loss
)
?
;
let
actor_loss
=
self
.critic
.forward
(
&
states
,
&
self
.actor
.forward
(
&
states
)
?
)
?
.mean_all
()
?
.neg
()
?
;
self
.actor_optim
.backward_step
(
&
actor_loss
)
?
;
self
.critic
.track
(
self
.tau
)
?
;
self
.actor
.track
(
self
.tau
)
?
;
Ok
(())
}
}
// The impact of the q value of the next state on the current state's q value.
const
GAMMA
:
f64
=
0.99
;
// The weight for updating the target networks.
const
TAU
:
f64
=
0.005
;
// The capacity of the replay buffer used for sampling training data.
const
REPLAY_BUFFER_CAPACITY
:
usize
=
100_000
;
// The training batch size for each training iteration.
const
TRAINING_BATCH_SIZE
:
usize
=
100
;
// The total number of episodes.
const
MAX_EPISODES
:
usize
=
100
;
// The maximum length of an episode.
const
EPISODE_LENGTH
:
usize
=
200
;
// The number of training iterations after one episode finishes.
const
TRAINING_ITERATIONS
:
usize
=
200
;
// Ornstein-Uhlenbeck process parameters.
const
MU
:
f64
=
0.0
;
const
THETA
:
f64
=
0.15
;
const
SIGMA
:
f64
=
0.1
;
const
ACTOR_LEARNING_RATE
:
f64
=
1e-4
;
const
CRITIC_LEARNING_RATE
:
f64
=
1e-3
;
pub
fn
run
()
->
Result
<
()
>
{
let
env
=
GymEnv
::
new
(
"Pendulum-v1"
)
?
;
println!
(
"action space: {}"
,
env
.action_space
());
println!
(
"observation space: {:?}"
,
env
.observation_space
());
let
size_state
=
env
.observation_space
()
.iter
()
.product
::
<
usize
>
();
let
size_action
=
env
.action_space
();
let
mut
agent
=
DDPG
::
new
(
&
Device
::
Cpu
,
size_state
,
size_action
,
true
,
ACTOR_LEARNING_RATE
,
CRITIC_LEARNING_RATE
,
GAMMA
,
TAU
,
REPLAY_BUFFER_CAPACITY
,
OuNoise
::
new
(
MU
,
THETA
,
SIGMA
,
size_action
)
?
,
)
?
;
let
mut
rng
=
rand
::
thread_rng
();
for
episode
in
0
..
MAX_EPISODES
{
// let mut state = env.reset(episode as u64)?;
let
mut
state
=
env
.reset
(
rng
.gen
::
<
u64
>
())
?
;
let
mut
total_reward
=
0.0
;
for
_
in
0
..
EPISODE_LENGTH
{
let
mut
action
=
2.0
*
agent
.actions
(
&
state
)
?
;
action
=
action
.clamp
(
-
2.0
,
2.0
);
let
step
=
env
.step
(
vec!
[
action
])
?
;
total_reward
+=
step
.reward
;
agent
.remember
(
&
state
,
&
Tensor
::
new
(
vec!
[
action
],
&
Device
::
Cpu
)
?
,
&
Tensor
::
new
(
vec!
[
step
.reward
as
f32
],
&
Device
::
Cpu
)
?
,
&
step
.state
,
step
.terminated
,
step
.truncated
,
);
if
step
.terminated
||
step
.truncated
{
break
;
}
state
=
step
.state
;
}
println!
(
"episode {episode} with total reward of {total_reward}"
);
for
_
in
0
..
TRAINING_ITERATIONS
{
agent
.train
(
TRAINING_BATCH_SIZE
)
?
;
}
}
println!
(
"Testing..."
);
agent
.train
=
false
;
for
episode
in
0
..
10
{
// let mut state = env.reset(episode as u64)?;
let
mut
state
=
env
.reset
(
rng
.gen
::
<
u64
>
())
?
;
let
mut
total_reward
=
0.0
;
for
_
in
0
..
EPISODE_LENGTH
{
let
mut
action
=
2.0
*
agent
.actions
(
&
state
)
?
;
action
=
action
.clamp
(
-
2.0
,
2.0
);
let
step
=
env
.step
(
vec!
[
action
])
?
;
total_reward
+=
step
.reward
;
if
step
.terminated
||
step
.truncated
{
break
;
}
state
=
step
.state
;
}
println!
(
"episode {episode} with total reward of {total_reward}"
);
}
Ok
(())
}
candle-examples/examples/reinforcement-learning/dqn.rs
0 → 100644
View file @
25d2752f
use
std
::
collections
::
VecDeque
;
use
rand
::
distributions
::
Uniform
;
use
rand
::{
thread_rng
,
Rng
};
use
candle
::{
DType
,
Device
,
Module
,
Result
,
Tensor
};
use
candle_nn
::
loss
::
mse
;
use
candle_nn
::{
linear
,
seq
,
Activation
,
AdamW
,
Optimizer
,
VarBuilder
,
VarMap
};
use
crate
::
gym_env
::
GymEnv
;
const
DEVICE
:
Device
=
Device
::
Cpu
;
const
EPISODES
:
usize
=
200
;
const
BATCH_SIZE
:
usize
=
64
;
const
GAMMA
:
f64
=
0.99
;
const
LEARNING_RATE
:
f64
=
0.01
;
pub
fn
run
()
->
Result
<
()
>
{
let
env
=
GymEnv
::
new
(
"CartPole-v1"
)
?
;
// Build the model that predicts the estimated rewards given a specific state.
let
var_map
=
VarMap
::
new
();
let
vb
=
VarBuilder
::
from_varmap
(
&
var_map
,
DType
::
F32
,
&
DEVICE
);
let
observation_space
=
*
env
.observation_space
()
.first
()
.unwrap
();
let
model
=
seq
()
.add
(
linear
(
observation_space
,
64
,
vb
.pp
(
"linear_in"
))
?
)
.add
(
Activation
::
Relu
)
.add
(
linear
(
64
,
env
.action_space
(),
vb
.pp
(
"linear_out"
))
?
);
let
mut
optimizer
=
AdamW
::
new_lr
(
var_map
.all_vars
(),
LEARNING_RATE
)
?
;
// Initialize the model's memory.
let
mut
memory
=
VecDeque
::
with_capacity
(
10000
);
// Start the training loop.
let
mut
state
=
env
.reset
(
0
)
?
;
let
mut
episode
=
0
;
let
mut
accumulate_rewards
=
0.0
;
while
episode
<
EPISODES
{
// Given the current state, predict the estimated rewards, and take the
// action that is expected to return the most rewards.
let
estimated_rewards
=
model
.forward
(
&
state
.unsqueeze
(
0
)
?
)
?
;
let
action
:
u32
=
estimated_rewards
.squeeze
(
0
)
?
.argmax
(
0
)
?
.to_scalar
()
?
;
// Take that action in the environment, and memorize the outcome:
// - the state for which the action was taken
// - the action taken
// - the new state resulting of taking that action
// - the actual rewards of taking that action
// - whether the environment reached a terminal state or not (e.g. game over)
let
step
=
env
.step
(
action
)
?
;
accumulate_rewards
+=
step
.reward
;
memory
.push_back
((
state
,
action
,
step
.state
.clone
(),
step
.reward
,
step
.terminated
||
step
.truncated
,
));
state
=
step
.state
;
// If there's enough entries in the memory, perform a learning step, where
// BATCH_SIZE transitions will be sampled from the memory and will be
// fed to the model so that it performs a backward pass.
if
memory
.len
()
>
BATCH_SIZE
{
// Sample randomly from the memory.
let
batch
=
thread_rng
()
.sample_iter
(
Uniform
::
from
(
0
..
memory
.len
()))
.take
(
BATCH_SIZE
)
.map
(|
i
|
memory
.get
(
i
)
.unwrap
()
.clone
())
.collect
::
<
Vec
<
_
>>
();
// Group all the samples together into tensors with the appropriate shape.
let
states
:
Vec
<
_
>
=
batch
.iter
()
.map
(|
e
|
e
.0
.clone
())
.collect
();
let
states
=
Tensor
::
stack
(
&
states
,
0
)
?
;
let
actions
=
batch
.iter
()
.map
(|
e
|
e
.1
);
let
actions
=
Tensor
::
from_iter
(
actions
,
&
DEVICE
)
?
.unsqueeze
(
1
)
?
;
let
next_states
:
Vec
<
_
>
=
batch
.iter
()
.map
(|
e
|
e
.2
.clone
())
.collect
();
let
next_states
=
Tensor
::
stack
(
&
next_states
,
0
)
?
;
let
rewards
=
batch
.iter
()
.map
(|
e
|
e
.3
as
f32
);
let
rewards
=
Tensor
::
from_iter
(
rewards
,
&
DEVICE
)
?
.unsqueeze
(
1
)
?
;
let
non_final_mask
=
batch
.iter
()
.map
(|
e
|
!
e
.4
as
u8
as
f32
);
let
non_final_mask
=
Tensor
::
from_iter
(
non_final_mask
,
&
DEVICE
)
?
.unsqueeze
(
1
)
?
;
// Get the estimated rewards for the actions that where taken at each step.
let
estimated_rewards
=
model
.forward
(
&
states
)
?
;
let
x
=
estimated_rewards
.gather
(
&
actions
,
1
)
?
;
// Get the maximum expected rewards for the next state, apply them a discount rate
// GAMMA and add them to the rewards that were actually gathered on the current state.
// If the next state is a terminal state, just omit maximum estimated
// rewards for that state.
let
expected_rewards
=
model
.forward
(
&
next_states
)
?
.detach
();
let
y
=
expected_rewards
.max_keepdim
(
1
)
?
;
let
y
=
(
y
*
GAMMA
*
non_final_mask
+
rewards
)
?
;
// Compare the estimated rewards with the maximum expected rewards and
// perform the backward step.
let
loss
=
mse
(
&
x
,
&
y
)
?
;
optimizer
.backward_step
(
&
loss
)
?
;
}
// If we are on a terminal state, reset the environment and log how it went.
if
step
.terminated
||
step
.truncated
{
episode
+=
1
;
println!
(
"Episode {episode} | Rewards {}"
,
accumulate_rewards
as
i64
);
state
=
env
.reset
(
0
)
?
;
accumulate_rewards
=
0.0
;
}
}
Ok
(())
}
candle-examples/examples/reinforcement-learning/gym_env.rs
0 → 100644
View file @
25d2752f
#![allow(unused)]
//! Wrappers around the Python API of Gymnasium (the new version of OpenAI gym)
use
candle
::{
Device
,
Result
,
Tensor
};
use
pyo3
::
prelude
::
*
;
use
pyo3
::
types
::
PyDict
;
/// The return value for a step.
#[derive(Debug)]
pub
struct
Step
<
A
>
{
pub
state
:
Tensor
,
pub
action
:
A
,
pub
reward
:
f64
,
pub
terminated
:
bool
,
pub
truncated
:
bool
,
}
impl
<
A
:
Copy
>
Step
<
A
>
{
/// Returns a copy of this step changing the observation tensor.
pub
fn
copy_with_obs
(
&
self
,
state
:
&
Tensor
)
->
Step
<
A
>
{
Step
{
state
:
state
.clone
(),
action
:
self
.action
,
reward
:
self
.reward
,
terminated
:
self
.terminated
,
truncated
:
self
.truncated
,
}
}
}
/// An OpenAI Gym session.
pub
struct
GymEnv
{
env
:
PyObject
,
action_space
:
usize
,
observation_space
:
Vec
<
usize
>
,
}
fn
w
(
res
:
PyErr
)
->
candle
::
Error
{
candle
::
Error
::
wrap
(
res
)
}
impl
GymEnv
{
/// Creates a new session of the specified OpenAI Gym environment.
pub
fn
new
(
name
:
&
str
)
->
Result
<
GymEnv
>
{
Python
::
with_gil
(|
py
|
{
let
gym
=
py
.import_bound
(
"gymnasium"
)
?
;
let
make
=
gym
.getattr
(
"make"
)
?
;
let
env
=
make
.call1
((
name
,))
?
;
let
action_space
=
env
.getattr
(
"action_space"
)
?
;
let
action_space
=
if
let
Ok
(
val
)
=
action_space
.getattr
(
"n"
)
{
val
.extract
()
?
}
else
{
let
action_space
:
Vec
<
usize
>
=
action_space
.getattr
(
"shape"
)
?
.extract
()
?
;
action_space
[
0
]
};
let
observation_space
=
env
.getattr
(
"observation_space"
)
?
;
let
observation_space
=
observation_space
.getattr
(
"shape"
)
?
.extract
()
?
;
Ok
(
GymEnv
{
env
:
env
.into
(),
action_space
,
observation_space
,
})
})
.map_err
(
w
)
}
/// Resets the environment, returning the observation tensor.
pub
fn
reset
(
&
self
,
seed
:
u64
)
->
Result
<
Tensor
>
{
let
state
:
Vec
<
f32
>
=
Python
::
with_gil
(|
py
|
{
let
kwargs
=
PyDict
::
new_bound
(
py
);
kwargs
.set_item
(
"seed"
,
seed
)
?
;
let
state
=
self
.env
.call_method_bound
(
py
,
"reset"
,
(),
Some
(
&
kwargs
))
?
;
state
.bind
(
py
)
.get_item
(
0
)
?
.extract
()
})
.map_err
(
w
)
?
;
Tensor
::
new
(
state
,
&
Device
::
Cpu
)
}
/// Applies an environment step using the specified action.
pub
fn
step
<
A
:
pyo3
::
IntoPy
<
pyo3
::
Py
<
pyo3
::
PyAny
>>
+
Clone
>
(
&
self
,
action
:
A
,
)
->
Result
<
Step
<
A
>>
{
let
(
state
,
reward
,
terminated
,
truncated
)
=
Python
::
with_gil
(|
py
|
{
let
step
=
self
.env
.call_method_bound
(
py
,
"step"
,
(
action
.clone
(),),
None
)
?
;
let
step
=
step
.bind
(
py
);
let
state
:
Vec
<
f32
>
=
step
.get_item
(
0
)
?
.extract
()
?
;
let
reward
:
f64
=
step
.get_item
(
1
)
?
.extract
()
?
;
let
terminated
:
bool
=
step
.get_item
(
2
)
?
.extract
()
?
;
let
truncated
:
bool
=
step
.get_item
(
3
)
?
.extract
()
?
;
Ok
((
state
,
reward
,
terminated
,
truncated
))
})
.map_err
(
w
)
?
;
let
state
=
Tensor
::
new
(
state
,
&
Device
::
Cpu
)
?
;
Ok
(
Step
{
state
,
action
,
reward
,
terminated
,
truncated
,
})
}
/// Returns the number of allowed actions for this environment.
pub
fn
action_space
(
&
self
)
->
usize
{
self
.action_space
}
/// Returns the shape of the observation tensors.
pub
fn
observation_space
(
&
self
)
->
&
[
usize
]
{
&
self
.observation_space
}
}
candle-examples/examples/reinforcement-learning/main.rs
0 → 100644
View file @
25d2752f
#![allow(unused)]
#[cfg(any(feature
=
"mkl"
,
feature
=
"mkl-dynamic"
))]
extern
crate
intel_mkl_src
;
#[cfg(feature
=
"accelerate"
)]
extern
crate
accelerate_src
;
use
candle
::
Result
;
use
clap
::{
Parser
,
Subcommand
};
mod
gym_env
;
mod
vec_gym_env
;
mod
ddpg
;
mod
dqn
;
mod
policy_gradient
;
#[derive(Parser)]
struct
Args
{
#[command(subcommand)]
command
:
Command
,
}
#[derive(Subcommand)]
enum
Command
{
Pg
,
Ddpg
,
Dqn
,
}
fn
main
()
->
Result
<
()
>
{
let
args
=
Args
::
parse
();
match
args
.command
{
Command
::
Pg
=>
policy_gradient
::
run
()
?
,
Command
::
Ddpg
=>
ddpg
::
run
()
?
,
Command
::
Dqn
=>
dqn
::
run
()
?
,
}
Ok
(())
}
candle-examples/examples/reinforcement-learning/policy_gradient.rs
0 → 100644
View file @
25d2752f
use
super
::
gym_env
::{
GymEnv
,
Step
};
use
candle
::{
DType
,
Device
,
Error
,
Module
,
Result
,
Tensor
};
use
candle_nn
::{
linear
,
ops
::
log_softmax
,
ops
::
softmax
,
sequential
::
seq
,
Activation
,
AdamW
,
Optimizer
,
ParamsAdamW
,
VarBuilder
,
VarMap
,
};
use
rand
::{
distributions
::
Distribution
,
rngs
::
ThreadRng
,
Rng
};
fn
new_model
(
input_shape
:
&
[
usize
],
num_actions
:
usize
,
dtype
:
DType
,
device
:
&
Device
,
)
->
Result
<
(
impl
Module
,
VarMap
)
>
{
let
input_size
=
input_shape
.iter
()
.product
();
let
mut
varmap
=
VarMap
::
new
();
let
var_builder
=
VarBuilder
::
from_varmap
(
&
varmap
,
dtype
,
device
);
let
model
=
seq
()
.add
(
linear
(
input_size
,
32
,
var_builder
.pp
(
"lin1"
))
?
)
.add
(
Activation
::
Relu
)
.add
(
linear
(
32
,
num_actions
,
var_builder
.pp
(
"lin2"
))
?
);
Ok
((
model
,
varmap
))
}
fn
accumulate_rewards
(
steps
:
&
[
Step
<
i64
>
])
->
Vec
<
f64
>
{
let
mut
rewards
:
Vec
<
f64
>
=
steps
.iter
()
.map
(|
s
|
s
.reward
)
.collect
();
let
mut
acc_reward
=
0f64
;
for
(
i
,
reward
)
in
rewards
.iter_mut
()
.enumerate
()
.rev
()
{
if
steps
[
i
]
.terminated
{
acc_reward
=
0.0
;
}
acc_reward
+=
*
reward
;
*
reward
=
acc_reward
;
}
rewards
}
fn
weighted_sample
(
probs
:
Vec
<
f32
>
,
rng
:
&
mut
ThreadRng
)
->
Result
<
usize
>
{
let
distribution
=
rand
::
distributions
::
WeightedIndex
::
new
(
probs
)
.map_err
(
Error
::
wrap
)
?
;
let
mut
rng
=
rng
;
Ok
(
distribution
.sample
(
&
mut
rng
))
}
pub
fn
run
()
->
Result
<
()
>
{
let
env
=
GymEnv
::
new
(
"CartPole-v1"
)
?
;
println!
(
"action space: {:?}"
,
env
.action_space
());
println!
(
"observation space: {:?}"
,
env
.observation_space
());
let
(
model
,
varmap
)
=
new_model
(
env
.observation_space
(),
env
.action_space
(),
DType
::
F32
,
&
Device
::
Cpu
,
)
?
;
let
optimizer_params
=
ParamsAdamW
{
lr
:
0.01
,
weight_decay
:
0.01
,
..
Default
::
default
()
};
let
mut
optimizer
=
AdamW
::
new
(
varmap
.all_vars
(),
optimizer_params
)
?
;
let
mut
rng
=
rand
::
thread_rng
();
for
epoch_idx
in
0
..
100
{
let
mut
state
=
env
.reset
(
rng
.gen
::
<
u64
>
())
?
;
let
mut
steps
:
Vec
<
Step
<
i64
>>
=
vec!
[];
loop
{
let
action
=
{
let
action_probs
:
Vec
<
f32
>
=
softmax
(
&
model
.forward
(
&
state
.detach
()
.unsqueeze
(
0
)
?
)
?
,
1
)
?
.squeeze
(
0
)
?
.to_vec1
()
?
;
weighted_sample
(
action_probs
,
&
mut
rng
)
?
as
i64
};
let
step
=
env
.step
(
action
)
?
;
steps
.push
(
step
.copy_with_obs
(
&
state
));
if
step
.terminated
||
step
.truncated
{
state
=
env
.reset
(
rng
.gen
::
<
u64
>
())
?
;
if
steps
.len
()
>
5000
{
break
;
}
}
else
{
state
=
step
.state
;
}
}
let
total_reward
:
f64
=
steps
.iter
()
.map
(|
s
|
s
.reward
)
.sum
();
let
episodes
:
i64
=
steps
.iter
()
.map
(|
s
|
(
s
.terminated
||
s
.truncated
)
as
i64
)
.sum
();
println!
(
"epoch: {:<3} episodes: {:<5} avg reward per episode: {:.2}"
,
epoch_idx
,
episodes
,
total_reward
/
episodes
as
f64
);
let
batch_size
=
steps
.len
();
let
rewards
=
Tensor
::
from_vec
(
accumulate_rewards
(
&
steps
),
batch_size
,
&
Device
::
Cpu
)
?
.to_dtype
(
DType
::
F32
)
?
.detach
();
let
actions_mask
=
{
let
actions
:
Vec
<
i64
>
=
steps
.iter
()
.map
(|
s
|
s
.action
)
.collect
();
let
actions_mask
:
Vec
<
Tensor
>
=
actions
.iter
()
.map
(|
&
action
|
{
// One-hot encoding
let
mut
action_mask
=
vec!
[
0.0
;
env
.action_space
()];
action_mask
[
action
as
usize
]
=
1.0
;
Tensor
::
from_vec
(
action_mask
,
env
.action_space
(),
&
Device
::
Cpu
)
.unwrap
()
.to_dtype
(
DType
::
F32
)
.unwrap
()
})
.collect
();
Tensor
::
stack
(
&
actions_mask
,
0
)
?
.detach
()
};
let
states
=
{
let
states
:
Vec
<
Tensor
>
=
steps
.into_iter
()
.map
(|
s
|
s
.state
)
.collect
();
Tensor
::
stack
(
&
states
,
0
)
?
.detach
()
};
let
log_probs
=
actions_mask
.mul
(
&
log_softmax
(
&
model
.forward
(
&
states
)
?
,
1
)
?
)
?
.sum
(
1
)
?
;
let
loss
=
rewards
.mul
(
&
log_probs
)
?
.neg
()
?
.mean_all
()
?
;
optimizer
.backward_step
(
&
loss
)
?
;
}
Ok
(())
}
candle-examples/examples/reinforcement-learning/vec_gym_env.rs
0 → 100644
View file @
25d2752f
#![allow(unused)]
//! Vectorized version of the gym environment.
use
candle
::{
DType
,
Device
,
Result
,
Tensor
};
use
pyo3
::
prelude
::
*
;
use
pyo3
::
types
::
PyDict
;
#[derive(Debug)]
pub
struct
Step
{
pub
obs
:
Tensor
,
pub
reward
:
Tensor
,
pub
is_done
:
Tensor
,
}
pub
struct
VecGymEnv
{
env
:
PyObject
,
action_space
:
usize
,
observation_space
:
Vec
<
usize
>
,
}
fn
w
(
res
:
PyErr
)
->
candle
::
Error
{
candle
::
Error
::
wrap
(
res
)
}
impl
VecGymEnv
{
pub
fn
new
(
name
:
&
str
,
img_dir
:
Option
<&
str
>
,
nprocesses
:
usize
)
->
Result
<
VecGymEnv
>
{
Python
::
with_gil
(|
py
|
{
let
sys
=
py
.import_bound
(
"sys"
)
?
;
let
path
=
sys
.getattr
(
"path"
)
?
;
let
_
=
path
.call_method1
(
"append"
,
(
"candle-examples/examples/reinforcement-learning"
,),
)
?
;
let
gym
=
py
.import_bound
(
"atari_wrappers"
)
?
;
let
make
=
gym
.getattr
(
"make"
)
?
;
let
env
=
make
.call1
((
name
,
img_dir
,
nprocesses
))
?
;
let
action_space
=
env
.getattr
(
"action_space"
)
?
;
let
action_space
=
action_space
.getattr
(
"n"
)
?
.extract
()
?
;
let
observation_space
=
env
.getattr
(
"observation_space"
)
?
;
let
observation_space
:
Vec
<
usize
>
=
observation_space
.getattr
(
"shape"
)
?
.extract
()
?
;
let
observation_space
=
[
vec!
[
nprocesses
]
.as_slice
(),
observation_space
.as_slice
()]
.concat
();
Ok
(
VecGymEnv
{
env
:
env
.into
(),
action_space
,
observation_space
,
})
})
.map_err
(
w
)
}
pub
fn
reset
(
&
self
)
->
Result
<
Tensor
>
{
let
obs
=
Python
::
with_gil
(|
py
|
{
let
obs
=
self
.env
.call_method0
(
py
,
"reset"
)
?
;
let
obs
=
obs
.call_method0
(
py
,
"flatten"
)
?
;
obs
.extract
::
<
Vec
<
f32
>>
(
py
)
})
.map_err
(
w
)
?
;
Tensor
::
new
(
obs
,
&
Device
::
Cpu
)
?
.reshape
(
self
.observation_space
.as_slice
())
}
pub
fn
step
(
&
self
,
action
:
Vec
<
usize
>
)
->
Result
<
Step
>
{
let
(
obs
,
reward
,
is_done
)
=
Python
::
with_gil
(|
py
|
{
let
step
=
self
.env
.call_method_bound
(
py
,
"step"
,
(
action
,),
None
)
?
;
let
step
=
step
.bind
(
py
);
let
obs
=
step
.get_item
(
0
)
?
.call_method
(
"flatten"
,
(),
None
)
?
;
let
obs_buffer
=
pyo3
::
buffer
::
PyBuffer
::
get_bound
(
&
obs
)
?
;
let
obs
:
Vec
<
u8
>
=
obs_buffer
.to_vec
(
py
)
?
;
let
reward
:
Vec
<
f32
>
=
step
.get_item
(
1
)
?
.extract
()
?
;
let
is_done
:
Vec
<
f32
>
=
step
.get_item
(
2
)
?
.extract
()
?
;
Ok
((
obs
,
reward
,
is_done
))
})
.map_err
(
w
)
?
;
let
obs
=
Tensor
::
from_vec
(
obs
,
self
.observation_space
.as_slice
(),
&
Device
::
Cpu
)
?
.to_dtype
(
DType
::
F32
)
?
;
let
reward
=
Tensor
::
new
(
reward
,
&
Device
::
Cpu
)
?
;
let
is_done
=
Tensor
::
new
(
is_done
,
&
Device
::
Cpu
)
?
;
Ok
(
Step
{
obs
,
reward
,
is_done
,
})
}
pub
fn
action_space
(
&
self
)
->
usize
{
self
.action_space
}
pub
fn
observation_space
(
&
self
)
->
&
[
usize
]
{
&
self
.observation_space
}
}
candle-examples/examples/replit-code/README.md
0 → 100644
View file @
25d2752f
# candle-replit-code: code completion specialized model.
[
replit-code-v1_5-3b
](
https://huggingface.co/replit/replit-code-v1_5-3b
)
is a
language model specialized for code completion. This model uses 3.3B parameters
in
`bfloat16`
(so the GPU version will only work on recent nvidia cards).
## Running some example
```
bash
cargo run
--example
replit-code
--release
--
--prompt
'def fibonacci(n): '
```
This produces the following output.
```
def fibonacci(n): # write Fibonacci series up to n
"""Print a Fibonacci series up to n."""
a, b = 0, 1
while a < n:
print(a, end=' ')
a, b = b, a+b
print()
def fibonacci_loop(n): # write Fibonacci series up to n
"""Print a Fibonacci series up to n."""
result = []
a, b = 0, 1
while a < n:
result.append(a)
a, b = b, a+b
return result
def fibonacci_generator(n): # write Fibonacci series up to n
"""Print a Fibonacci series up to n."""
a, b = 0, 1
while a < n:
yield a
a, b = b, a+b
```
candle-examples/examples/replit-code/main.rs
0 → 100644
View file @
25d2752f
#[cfg(any(feature
=
"mkl"
,
feature
=
"mkl-dynamic"
))]
extern
crate
intel_mkl_src
;
#[cfg(feature
=
"accelerate"
)]
extern
crate
accelerate_src
;
use
anyhow
::{
Error
as
E
,
Result
};
use
clap
::
Parser
;
use
candle_transformers
::
models
::
mpt
::{
Config
,
Model
as
M
};
use
candle_transformers
::
models
::
quantized_mpt
::
Model
as
Q
;
use
candle
::{
DType
,
Device
,
Tensor
};
use
candle_nn
::
VarBuilder
;
use
candle_transformers
::
generation
::
LogitsProcessor
;
use
hf_hub
::{
api
::
sync
::
Api
,
Repo
,
RepoType
};
use
tokenizers
::
Tokenizer
;
enum
Model
{
M
(
M
),
Q
(
Q
),
}
impl
Model
{
fn
forward
(
&
mut
self
,
xs
:
&
Tensor
)
->
candle
::
Result
<
Tensor
>
{
match
self
{
Self
::
M
(
model
)
=>
model
.forward
(
xs
),
Self
::
Q
(
model
)
=>
model
.forward
(
xs
),
}
}
}
struct
TextGeneration
{
model
:
Model
,
device
:
Device
,
tokenizer
:
Tokenizer
,
logits_processor
:
LogitsProcessor
,
repeat_penalty
:
f32
,
repeat_last_n
:
usize
,
verbose_prompt
:
bool
,
}
impl
TextGeneration
{
#[allow(clippy::too_many_arguments)]
fn
new
(
model
:
Model
,
tokenizer
:
Tokenizer
,
seed
:
u64
,
temp
:
Option
<
f64
>
,
top_p
:
Option
<
f64
>
,
repeat_penalty
:
f32
,
repeat_last_n
:
usize
,
verbose_prompt
:
bool
,
device
:
&
Device
,
)
->
Self
{
let
logits_processor
=
LogitsProcessor
::
new
(
seed
,
temp
,
top_p
);
Self
{
model
,
tokenizer
,
logits_processor
,
repeat_penalty
,
repeat_last_n
,
verbose_prompt
,
device
:
device
.clone
(),
}
}
fn
run
(
&
mut
self
,
prompt
:
&
str
,
sample_len
:
usize
)
->
Result
<
()
>
{
use
std
::
io
::
Write
;
println!
(
"starting the inference loop"
);
let
tokens
=
self
.tokenizer
.encode
(
prompt
,
true
)
.map_err
(
E
::
msg
)
?
;
if
tokens
.is_empty
()
{
anyhow
::
bail!
(
"Empty prompts are not supported in the phi model."
)
}
if
self
.verbose_prompt
{
for
(
token
,
id
)
in
tokens
.get_tokens
()
.iter
()
.zip
(
tokens
.get_ids
()
.iter
())
{
let
token
=
token
.replace
(
'▁'
,
" "
)
.replace
(
"<0x0A>"
,
"
\n
"
);
println!
(
"{id:7} -> '{token}'"
);
}
}
let
mut
tokens
=
tokens
.get_ids
()
.to_vec
();
let
mut
generated_tokens
=
0u
size
;
let
eos_token
=
match
self
.tokenizer
.get_vocab
(
true
)
.get
(
"<|endoftext|>"
)
{
Some
(
token
)
=>
*
token
,
None
=>
anyhow
::
bail!
(
"cannot find the endoftext token"
),
};
print!
(
"{prompt}"
);
std
::
io
::
stdout
()
.flush
()
?
;
let
start_gen
=
std
::
time
::
Instant
::
now
();
for
index
in
0
..
sample_len
{
let
context_size
=
if
index
>
0
{
1
}
else
{
tokens
.len
()
};
let
ctxt
=
&
tokens
[
tokens
.len
()
.saturating_sub
(
context_size
)
..
];
let
input
=
Tensor
::
new
(
ctxt
,
&
self
.device
)
?
.unsqueeze
(
0
)
?
;
let
logits
=
self
.model
.forward
(
&
input
)
?
;
let
logits
=
logits
.squeeze
(
0
)
?
.to_dtype
(
DType
::
F32
)
?
;
let
logits
=
if
self
.repeat_penalty
==
1
.
{
logits
}
else
{
let
start_at
=
tokens
.len
()
.saturating_sub
(
self
.repeat_last_n
);
candle_transformers
::
utils
::
apply_repeat_penalty
(
&
logits
,
self
.repeat_penalty
,
&
tokens
[
start_at
..
],
)
?
};
let
next_token
=
self
.logits_processor
.sample
(
&
logits
)
?
;
tokens
.push
(
next_token
);
generated_tokens
+=
1
;
if
next_token
==
eos_token
{
break
;
}
let
token
=
self
.tokenizer
.decode
(
&
[
next_token
],
true
)
.map_err
(
E
::
msg
)
?
;
print!
(
"{token}"
);
std
::
io
::
stdout
()
.flush
()
?
;
}
let
dt
=
start_gen
.elapsed
();
println!
(
"
\n
{generated_tokens} tokens generated ({:.2} token/s)"
,
generated_tokens
as
f64
/
dt
.as_secs_f64
(),
);
Ok
(())
}
}
#[derive(Parser,
Debug)]
#[command(author,
version,
about,
long_about
=
None)]
struct
Args
{
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu
:
bool
,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing
:
bool
,
/// Display the token for the specified prompt.
#[arg(long)]
verbose_prompt
:
bool
,
#[arg(long)]
prompt
:
String
,
/// The temperature used to generate samples.
#[arg(long)]
temperature
:
Option
<
f64
>
,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p
:
Option
<
f64
>
,
/// The seed to use when generating random samples.
#[arg(long,
default_value_t
=
299792458
)]
seed
:
u64
,
/// The length of the sample to generate (in tokens).
#[arg(long,
short
=
'n'
,
default_value_t
=
1000
)]
sample_len
:
usize
,
#[arg(long)]
model_id
:
Option
<
String
>
,
#[arg(long)]
revision
:
Option
<
String
>
,
#[arg(long)]
quantized
:
bool
,
#[arg(long)]
weight_file
:
Option
<
String
>
,
#[arg(long)]
tokenizer
:
Option
<
String
>
,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long,
default_value_t
=
1
.
)]
repeat_penalty
:
f32
,
/// The context size to consider for the repeat penalty.
#[arg(long,
default_value_t
=
64
)]
repeat_last_n
:
usize
,
}
fn
main
()
->
Result
<
()
>
{
use
tracing_chrome
::
ChromeLayerBuilder
;
use
tracing_subscriber
::
prelude
::
*
;
let
args
=
Args
::
parse
();
let
_
guard
=
if
args
.tracing
{
let
(
chrome_layer
,
guard
)
=
ChromeLayerBuilder
::
new
()
.build
();
tracing_subscriber
::
registry
()
.with
(
chrome_layer
)
.init
();
Some
(
guard
)
}
else
{
None
};
println!
(
"avx: {}, neon: {}, simd128: {}, f16c: {}"
,
candle
::
utils
::
with_avx
(),
candle
::
utils
::
with_neon
(),
candle
::
utils
::
with_simd128
(),
candle
::
utils
::
with_f16c
()
);
println!
(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}"
,
args
.temperature
.unwrap_or
(
0
.
),
args
.repeat_penalty
,
args
.repeat_last_n
);
let
start
=
std
::
time
::
Instant
::
now
();
let
api
=
Api
::
new
()
?
;
let
model_id
=
match
args
.model_id
{
Some
(
model_id
)
=>
model_id
.to_string
(),
None
=>
"lmz/candle-replit-code"
.to_string
(),
};
let
revision
=
match
args
.revision
{
Some
(
rev
)
=>
rev
.to_string
(),
None
=>
"main"
.to_string
(),
};
let
repo
=
api
.repo
(
Repo
::
with_revision
(
model_id
,
RepoType
::
Model
,
revision
));
let
tokenizer_filename
=
match
args
.tokenizer
{
Some
(
file
)
=>
std
::
path
::
PathBuf
::
from
(
file
),
None
=>
repo
.get
(
"tokenizer.json"
)
?
,
};
let
filename
=
match
args
.weight_file
{
Some
(
weight_file
)
=>
std
::
path
::
PathBuf
::
from
(
weight_file
),
None
=>
{
if
args
.quantized
{
repo
.get
(
"model-replit-code-v1_5-q4k.gguf"
)
?
}
else
{
repo
.get
(
"model.safetensors"
)
?
}
}
};
println!
(
"retrieved the files in {:?}"
,
start
.elapsed
());
let
tokenizer
=
Tokenizer
::
from_file
(
tokenizer_filename
)
.map_err
(
E
::
msg
)
?
;
let
start
=
std
::
time
::
Instant
::
now
();
let
device
=
candle_examples
::
device
(
args
.cpu
)
?
;
let
config
=
Config
::
replit_code_v1_5_3b
();
let
model
=
if
args
.quantized
{
let
vb
=
candle_transformers
::
quantized_var_builder
::
VarBuilder
::
from_gguf
(
&
filename
,
&
device
)
?
;
Model
::
Q
(
Q
::
new
(
&
config
,
vb
.pp
(
"transformer"
))
?
)
}
else
{
let
vb
=
unsafe
{
VarBuilder
::
from_mmaped_safetensors
(
&
[
filename
],
DType
::
F32
,
&
device
)
?
};
Model
::
M
(
M
::
new
(
&
config
,
vb
.pp
(
"transformer"
))
?
)
};
println!
(
"loaded the model in {:?}"
,
start
.elapsed
());
let
mut
pipeline
=
TextGeneration
::
new
(
model
,
tokenizer
,
args
.seed
,
args
.temperature
,
args
.top_p
,
args
.repeat_penalty
,
args
.repeat_last_n
,
args
.verbose_prompt
,
&
device
,
);
pipeline
.run
(
&
args
.prompt
,
args
.sample_len
)
?
;
Ok
(())
}
candle-examples/examples/repvgg/README.md
0 → 100644
View file @
25d2752f
# candle-repvgg
[
RepVGG: Making VGG-style ConvNets Great Again
](
https://arxiv.org/abs/2101.03697
)
.
This candle implementation uses a pre-trained RepVGG network for inference. The
classification head has been trained on the ImageNet dataset and returns the
probabilities for the top-5 classes.
## Running an example
```
$
cargo
run
--
example
repvgg
--
release
--
--
image
candle
-
examples
/
examples
/
yolo
-
v8
/
assets
/
bike
.
jpg
loaded
image
Tensor
[
dims
3
,
224
,
224
;
f32
]
model
built
mountain
bike
,
all
-
terrain
bike
,
off
-
roader
:
61.70
%
bicycle
-
built
-
for
-
two
,
tandem
bicycle
,
tandem
:
33.14
%
unicycle
,
monocycle
:
4.88
%
crash
helmet
:
0.15
%
moped
:
0.04
%
```
candle-examples/examples/repvgg/main.rs
0 → 100644
View file @
25d2752f
#[cfg(feature
=
"mkl"
)]
extern
crate
intel_mkl_src
;
#[cfg(feature
=
"accelerate"
)]
extern
crate
accelerate_src
;
use
clap
::{
Parser
,
ValueEnum
};
use
candle
::{
DType
,
IndexOp
,
D
};
use
candle_nn
::{
Module
,
VarBuilder
};
use
candle_transformers
::
models
::
repvgg
;
#[derive(Clone,
Copy,
Debug,
ValueEnum)]
enum
Which
{
A0
,
A1
,
A2
,
B0
,
B1
,
B2
,
B3
,
B1G4
,
B2G4
,
B3G4
,
}
impl
Which
{
fn
model_filename
(
&
self
)
->
String
{
let
name
=
match
self
{
Self
::
A0
=>
"a0"
,
Self
::
A1
=>
"a1"
,
Self
::
A2
=>
"a2"
,
Self
::
B0
=>
"b0"
,
Self
::
B1
=>
"b1"
,
Self
::
B2
=>
"b2"
,
Self
::
B3
=>
"b3"
,
Self
::
B1G4
=>
"b1g4"
,
Self
::
B2G4
=>
"b2g4"
,
Self
::
B3G4
=>
"b3g4"
,
};
format!
(
"timm/repvgg_{}.rvgg_in1k"
,
name
)
}
fn
config
(
&
self
)
->
repvgg
::
Config
{
match
self
{
Self
::
A0
=>
repvgg
::
Config
::
a0
(),
Self
::
A1
=>
repvgg
::
Config
::
a1
(),
Self
::
A2
=>
repvgg
::
Config
::
a2
(),
Self
::
B0
=>
repvgg
::
Config
::
b0
(),
Self
::
B1
=>
repvgg
::
Config
::
b1
(),
Self
::
B2
=>
repvgg
::
Config
::
b2
(),
Self
::
B3
=>
repvgg
::
Config
::
b3
(),
Self
::
B1G4
=>
repvgg
::
Config
::
b1g4
(),
Self
::
B2G4
=>
repvgg
::
Config
::
b2g4
(),
Self
::
B3G4
=>
repvgg
::
Config
::
b3g4
(),
}
}
}
#[derive(Parser)]
struct
Args
{
#[arg(long)]
model
:
Option
<
String
>
,
#[arg(long)]
image
:
String
,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu
:
bool
,
#[arg(value_enum,
long,
default_value_t=Which::A0)]
which
:
Which
,
}
pub
fn
main
()
->
anyhow
::
Result
<
()
>
{
let
args
=
Args
::
parse
();
let
device
=
candle_examples
::
device
(
args
.cpu
)
?
;
let
image
=
candle_examples
::
imagenet
::
load_image224
(
args
.image
)
?
.to_device
(
&
device
)
?
;
println!
(
"loaded image {image:?}"
);
let
model_file
=
match
args
.model
{
None
=>
{
let
model_name
=
args
.which
.model_filename
();
let
api
=
hf_hub
::
api
::
sync
::
Api
::
new
()
?
;
let
api
=
api
.model
(
model_name
);
api
.get
(
"model.safetensors"
)
?
}
Some
(
model
)
=>
model
.into
(),
};
let
vb
=
unsafe
{
VarBuilder
::
from_mmaped_safetensors
(
&
[
model_file
],
DType
::
F32
,
&
device
)
?
};
let
model
=
repvgg
::
repvgg
(
&
args
.which
.config
(),
1000
,
vb
)
?
;
println!
(
"model built"
);
let
logits
=
model
.forward
(
&
image
.unsqueeze
(
0
)
?
)
?
;
let
prs
=
candle_nn
::
ops
::
softmax
(
&
logits
,
D
::
Minus1
)
?
.i
(
0
)
?
.to_vec1
::
<
f32
>
()
?
;
let
mut
prs
=
prs
.iter
()
.enumerate
()
.collect
::
<
Vec
<
_
>>
();
prs
.sort_by
(|(
_
,
p1
),
(
_
,
p2
)|
p2
.total_cmp
(
p1
));
for
&
(
category_idx
,
pr
)
in
prs
.iter
()
.take
(
5
)
{
println!
(
"{:24}: {:.2}%"
,
candle_examples
::
imagenet
::
CLASSES
[
category_idx
],
100
.
*
pr
);
}
Ok
(())
}
candle-examples/examples/resnet/README.md
0 → 100644
View file @
25d2752f
# candle-resnet
A candle implementation of inference using a pre-trained
[
ResNet
](
https://arxiv.org/abs/1512.03385
)
.
This uses a classification head trained on the ImageNet dataset and returns the
probabilities for the top-5 classes.
## Running an example
```
$
cargo
run
--
example
resnet
--
release
--
--
image
tiger
.
jpg
loaded
image
Tensor
[
dims
3
,
224
,
224
;
f32
]
model
built
tiger
,
Panthera
tigris
:
90.21
%
tiger
cat
:
8.93
%
lion
,
king
of
beasts
,
Panthera
leo
:
0.35
%
leopard
,
Panthera
pardus
:
0.16
%
jaguar
,
panther
,
Panthera
onca
,
Felis
onca
:
0.09
%
```
candle-examples/examples/resnet/export_models.py
0 → 100644
View file @
25d2752f
# This script exports pre-trained model weights in the safetensors format.
import
numpy
as
np
import
torch
import
torchvision
from
safetensors
import
torch
as
stt
m
=
torchvision
.
models
.
resnet50
(
pretrained
=
True
)
stt
.
save_file
(
m
.
state_dict
(),
'resnet50.safetensors'
)
m
=
torchvision
.
models
.
resnet101
(
pretrained
=
True
)
stt
.
save_file
(
m
.
state_dict
(),
'resnet101.safetensors'
)
m
=
torchvision
.
models
.
resnet152
(
pretrained
=
True
)
stt
.
save_file
(
m
.
state_dict
(),
'resnet152.safetensors'
)
candle-examples/examples/resnet/main.rs
0 → 100644
View file @
25d2752f
#[cfg(any(feature
=
"mkl"
,
feature
=
"mkl-dynamic"
))]
extern
crate
intel_mkl_src
;
#[cfg(feature
=
"accelerate"
)]
extern
crate
accelerate_src
;
use
candle
::{
DType
,
IndexOp
,
D
};
use
candle_nn
::{
Module
,
VarBuilder
};
use
candle_transformers
::
models
::
resnet
;
use
clap
::{
Parser
,
ValueEnum
};
#[derive(Clone,
Copy,
Debug,
ValueEnum)]
enum
Which
{
#[value(name
=
"18"
)]
Resnet18
,
#[value(name
=
"34"
)]
Resnet34
,
#[value(name
=
"50"
)]
Resnet50
,
#[value(name
=
"101"
)]
Resnet101
,
#[value(name
=
"152"
)]
Resnet152
,
}
#[derive(Parser)]
struct
Args
{
#[arg(long)]
model
:
Option
<
String
>
,
#[arg(long)]
image
:
String
,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu
:
bool
,
/// Variant of the model to use.
#[arg(value_enum,
long,
default_value_t
=
Which::Resnet18)]
which
:
Which
,
}
pub
fn
main
()
->
anyhow
::
Result
<
()
>
{
let
args
=
Args
::
parse
();
let
device
=
candle_examples
::
device
(
args
.cpu
)
?
;
let
image
=
candle_examples
::
imagenet
::
load_image224
(
args
.image
)
?
.to_device
(
&
device
)
?
;
println!
(
"loaded image {image:?}"
);
let
model_file
=
match
args
.model
{
None
=>
{
let
api
=
hf_hub
::
api
::
sync
::
Api
::
new
()
?
;
let
api
=
api
.model
(
"lmz/candle-resnet"
.into
());
let
filename
=
match
args
.which
{
Which
::
Resnet18
=>
"resnet18.safetensors"
,
Which
::
Resnet34
=>
"resnet34.safetensors"
,
Which
::
Resnet50
=>
"resnet50.safetensors"
,
Which
::
Resnet101
=>
"resnet101.safetensors"
,
Which
::
Resnet152
=>
"resnet152.safetensors"
,
};
api
.get
(
filename
)
?
}
Some
(
model
)
=>
model
.into
(),
};
let
vb
=
unsafe
{
VarBuilder
::
from_mmaped_safetensors
(
&
[
model_file
],
DType
::
F32
,
&
device
)
?
};
let
class_count
=
candle_examples
::
imagenet
::
CLASS_COUNT
as
usize
;
let
model
=
match
args
.which
{
Which
::
Resnet18
=>
resnet
::
resnet18
(
class_count
,
vb
)
?
,
Which
::
Resnet34
=>
resnet
::
resnet34
(
class_count
,
vb
)
?
,
Which
::
Resnet50
=>
resnet
::
resnet50
(
class_count
,
vb
)
?
,
Which
::
Resnet101
=>
resnet
::
resnet101
(
class_count
,
vb
)
?
,
Which
::
Resnet152
=>
resnet
::
resnet152
(
class_count
,
vb
)
?
,
};
println!
(
"model built"
);
let
logits
=
model
.forward
(
&
image
.unsqueeze
(
0
)
?
)
?
;
let
prs
=
candle_nn
::
ops
::
softmax
(
&
logits
,
D
::
Minus1
)
?
.i
(
0
)
?
.to_vec1
::
<
f32
>
()
?
;
let
mut
prs
=
prs
.iter
()
.enumerate
()
.collect
::
<
Vec
<
_
>>
();
prs
.sort_by
(|(
_
,
p1
),
(
_
,
p2
)|
p2
.total_cmp
(
p1
));
for
&
(
category_idx
,
pr
)
in
prs
.iter
()
.take
(
5
)
{
println!
(
"{:24}: {:.2}%"
,
candle_examples
::
imagenet
::
CLASSES
[
category_idx
],
100
.
*
pr
);
}
Ok
(())
}
candle-examples/examples/rwkv/README.md
0 → 100644
View file @
25d2752f
## candle-rwkv
The
[
RWKV model
](
https://wiki.rwkv.com/
)
is a recurrent neural network model
with performance on par with transformer architectures. Several variants are
available, candle implements the v5 and v6 versions and can be used with
Eagle 7B(
[
blog post
](
https://blog.rwkv.com/p/eagle-7b-soaring-past-transformers
)
).
```
bash
$
cargo run
--example
rwkv
--release
--
--prompt
"The smallest prime is "
avx:
true
, neon:
false
, simd128:
false
, f16c:
true
temp: 0.00 repeat-penalty: 1.10 repeat-last-n: 64
The smallest prime is ϕ
(
2
)
=
2.
The smallest composite is ϕ
(
3
)
=
3.
The smallest perfect number is ϕ
(
5
)
=
5.
The smallest perfect square is ϕ
(
4
)
=
4.
The smallest perfect cube is ϕ
(
6
)
=
6.
```
candle-examples/examples/rwkv/main.rs
0 → 100644
View file @
25d2752f
#[cfg(feature
=
"mkl"
)]
extern
crate
intel_mkl_src
;
#[cfg(feature
=
"accelerate"
)]
extern
crate
accelerate_src
;
use
anyhow
::
Result
;
use
clap
::{
Parser
,
ValueEnum
};
use
candle_transformers
::
models
::
quantized_rwkv_v5
::
Model
as
Q5
;
use
candle_transformers
::
models
::
quantized_rwkv_v6
::
Model
as
Q6
;
use
candle_transformers
::
models
::
rwkv_v5
::{
Config
,
Model
as
M5
,
State
,
Tokenizer
};
use
candle_transformers
::
models
::
rwkv_v6
::
Model
as
M6
;
use
candle
::{
DType
,
Device
,
Tensor
};
use
candle_nn
::
VarBuilder
;
use
candle_transformers
::
generation
::
LogitsProcessor
;
use
hf_hub
::{
api
::
sync
::
Api
,
Repo
,
RepoType
};
const
EOS_TOKEN_ID
:
u32
=
261
;
enum
Model
{
M5
(
M5
),
Q5
(
Q5
),
M6
(
M6
),
Q6
(
Q6
),
}
impl
Model
{
fn
forward
(
&
self
,
xs
:
&
Tensor
,
state
:
&
mut
State
)
->
candle
::
Result
<
Tensor
>
{
match
self
{
Self
::
M5
(
m
)
=>
m
.forward
(
xs
,
state
),
Self
::
Q5
(
m
)
=>
m
.forward
(
xs
,
state
),
Self
::
M6
(
m
)
=>
m
.forward
(
xs
,
state
),
Self
::
Q6
(
m
)
=>
m
.forward
(
xs
,
state
),
}
}
}
struct
TextGeneration
{
model
:
Model
,
config
:
Config
,
device
:
Device
,
tokenizer
:
Tokenizer
,
logits_processor
:
LogitsProcessor
,
repeat_penalty
:
f32
,
repeat_last_n
:
usize
,
}
impl
TextGeneration
{
#[allow(clippy::too_many_arguments)]
fn
new
(
model
:
Model
,
config
:
Config
,
tokenizer
:
Tokenizer
,
seed
:
u64
,
temp
:
Option
<
f64
>
,
top_p
:
Option
<
f64
>
,
repeat_penalty
:
f32
,
repeat_last_n
:
usize
,
device
:
&
Device
,
)
->
Self
{
let
logits_processor
=
LogitsProcessor
::
new
(
seed
,
temp
,
top_p
);
Self
{
model
,
config
,
tokenizer
,
logits_processor
,
repeat_penalty
,
repeat_last_n
,
device
:
device
.clone
(),
}
}
fn
run
(
&
mut
self
,
prompt
:
&
str
,
sample_len
:
usize
)
->
Result
<
()
>
{
use
std
::
io
::
Write
;
let
mut
tokens
=
self
.tokenizer
.encode
(
prompt
)
?
;
let
mut
generated_tokens
=
0u
size
;
let
mut
state
=
State
::
new
(
1
,
&
self
.config
,
&
self
.device
)
?
;
let
mut
next_logits
=
None
;
for
&
t
in
tokens
.iter
()
{
let
input
=
Tensor
::
new
(
&
[[
t
]],
&
self
.device
)
?
;
let
logits
=
self
.model
.forward
(
&
input
,
&
mut
state
)
?
;
next_logits
=
Some
(
logits
);
print!
(
"{}"
,
self
.tokenizer
.decode
(
&
[
t
])
?
)
}
std
::
io
::
stdout
()
.flush
()
?
;
let
start_gen
=
std
::
time
::
Instant
::
now
();
for
_
in
0
..
sample_len
{
let
logits
=
match
next_logits
.as_ref
()
{
Some
(
logits
)
=>
logits
,
None
=>
anyhow
::
bail!
(
"cannot work on an empty prompt"
),
};
let
logits
=
logits
.squeeze
(
0
)
?
.squeeze
(
0
)
?
.to_dtype
(
DType
::
F32
)
?
;
let
logits
=
if
self
.repeat_penalty
==
1
.
{
logits
}
else
{
let
start_at
=
tokens
.len
()
.saturating_sub
(
self
.repeat_last_n
);
candle_transformers
::
utils
::
apply_repeat_penalty
(
&
logits
,
self
.repeat_penalty
,
&
tokens
[
start_at
..
],
)
?
};
let
next_token
=
self
.logits_processor
.sample
(
&
logits
)
?
;
tokens
.push
(
next_token
);
generated_tokens
+=
1
;
if
next_token
==
EOS_TOKEN_ID
||
next_token
==
0
{
break
;
}
print!
(
"{}"
,
self
.tokenizer
.decode
(
&
[
next_token
])
?
);
std
::
io
::
stdout
()
.flush
()
?
;
let
input
=
Tensor
::
new
(
&
[[
next_token
]],
&
self
.device
)
?
;
next_logits
=
Some
(
self
.model
.forward
(
&
input
,
&
mut
state
)
?
)
}
let
dt
=
start_gen
.elapsed
();
println!
(
"
\n
{generated_tokens} tokens generated ({:.2} token/s)"
,
generated_tokens
as
f64
/
dt
.as_secs_f64
(),
);
Ok
(())
}
}
#[derive(Parser,
ValueEnum,
Clone,
Copy,
PartialEq,
Eq,
Debug)]
enum
Which
{
Eagle7b
,
World1b5
,
World3b
,
World6_1b6
,
}
impl
std
::
fmt
::
Display
for
Which
{
fn
fmt
(
&
self
,
f
:
&
mut
std
::
fmt
::
Formatter
<
'_
>
)
->
std
::
fmt
::
Result
{
write!
(
f
,
"{:?}"
,
self
)
}
}
impl
Which
{
fn
model_id
(
&
self
)
->
&
'static
str
{
match
self
{
Self
::
Eagle7b
=>
"RWKV/v5-Eagle-7B-HF"
,
Self
::
World1b5
=>
"RWKV/rwkv-5-world-1b5"
,
Self
::
World3b
=>
"RWKV/rwkv-5-world-3b"
,
Self
::
World6_1b6
=>
"paperfun/rwkv"
,
}
}
fn
revision
(
&
self
)
->
&
'static
str
{
match
self
{
Self
::
Eagle7b
=>
"refs/pr/1"
,
Self
::
World1b5
|
Self
::
World3b
=>
"refs/pr/2"
,
Self
::
World6_1b6
=>
"main"
,
}
}
}
#[derive(Parser,
Debug)]
#[command(author,
version,
about,
long_about
=
None)]
struct
Args
{
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu
:
bool
,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing
:
bool
,
#[arg(long)]
prompt
:
String
,
/// The temperature used to generate samples.
#[arg(long)]
temperature
:
Option
<
f64
>
,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p
:
Option
<
f64
>
,
/// The seed to use when generating random samples.
#[arg(long,
default_value_t
=
299792458
)]
seed
:
u64
,
/// The length of the sample to generate (in tokens).
#[arg(long,
short
=
'n'
,
default_value_t
=
5000
)]
sample_len
:
usize
,
#[arg(long,
default_value
=
"world1b5"
)]
which
:
Which
,
#[arg(long)]
model_id
:
Option
<
String
>
,
#[arg(long)]
revision
:
Option
<
String
>
,
#[arg(long)]
tokenizer
:
Option
<
String
>
,
#[arg(long)]
weight_files
:
Option
<
String
>
,
#[arg(long)]
config_file
:
Option
<
String
>
,
#[arg(long)]
quantized
:
bool
,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long,
default_value_t
=
1.1
)]
repeat_penalty
:
f32
,
/// The context size to consider for the repeat penalty.
#[arg(long,
default_value_t
=
64
)]
repeat_last_n
:
usize
,
}
fn
main
()
->
Result
<
()
>
{
use
tracing_chrome
::
ChromeLayerBuilder
;
use
tracing_subscriber
::
prelude
::
*
;
let
args
=
Args
::
parse
();
let
_
guard
=
if
args
.tracing
{
let
(
chrome_layer
,
guard
)
=
ChromeLayerBuilder
::
new
()
.build
();
tracing_subscriber
::
registry
()
.with
(
chrome_layer
)
.init
();
Some
(
guard
)
}
else
{
None
};
println!
(
"avx: {}, neon: {}, simd128: {}, f16c: {}"
,
candle
::
utils
::
with_avx
(),
candle
::
utils
::
with_neon
(),
candle
::
utils
::
with_simd128
(),
candle
::
utils
::
with_f16c
()
);
println!
(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}"
,
args
.temperature
.unwrap_or
(
0
.
),
args
.repeat_penalty
,
args
.repeat_last_n
);
let
start
=
std
::
time
::
Instant
::
now
();
let
api
=
Api
::
new
()
?
;
let
repo
=
api
.repo
(
Repo
::
with_revision
(
args
.model_id
.unwrap_or_else
(||
args
.which
.model_id
()
.to_string
()),
RepoType
::
Model
,
args
.revision
.unwrap_or_else
(||
args
.which
.revision
()
.to_string
()),
));
let
tokenizer
=
match
args
.tokenizer
{
Some
(
file
)
=>
std
::
path
::
PathBuf
::
from
(
file
),
None
=>
api
.model
(
"lmz/candle-rwkv"
.to_string
())
.get
(
"rwkv_vocab_v20230424.json"
)
?
,
};
let
config_filename
=
match
args
.config_file
{
Some
(
file
)
=>
std
::
path
::
PathBuf
::
from
(
file
),
None
=>
repo
.get
(
"config.json"
)
?
,
};
let
filenames
=
match
args
.weight_files
{
Some
(
files
)
=>
files
.split
(
','
)
.map
(
std
::
path
::
PathBuf
::
from
)
.collect
::
<
Vec
<
_
>>
(),
None
=>
{
if
args
.quantized
{
vec!
[
match
args
.which
{
Which
::
World1b5
=>
api
.model
(
"lmz/candle-rwkv"
.to_string
())
.get
(
"world1b5-q4k.gguf"
)
?
,
Which
::
World3b
=>
api
.model
(
"lmz/candle-rwkv"
.to_string
())
.get
(
"world3b-q4k.gguf"
)
?
,
Which
::
Eagle7b
=>
api
.model
(
"lmz/candle-rwkv"
.to_string
())
.get
(
"eagle7b-q4k.gguf"
)
?
,
Which
::
World6_1b6
=>
repo
.get
(
"rwkv-6-world-1b6-q4k.gguf"
)
?
,
}]
}
else
{
vec!
[
match
args
.which
{
Which
::
World1b5
|
Which
::
World3b
|
Which
::
Eagle7b
=>
{
repo
.get
(
"model.safetensors"
)
?
}
Which
::
World6_1b6
=>
repo
.get
(
"rwkv-6-world-1b6.safetensors"
)
?
,
}]
}
}
};
println!
(
"retrieved the files in {:?}"
,
start
.elapsed
());
let
tokenizer
=
Tokenizer
::
new
(
tokenizer
)
?
;
let
start
=
std
::
time
::
Instant
::
now
();
let
config
:
Config
=
serde_json
::
from_slice
(
&
std
::
fs
::
read
(
config_filename
)
?
)
?
;
let
device
=
candle_examples
::
device
(
args
.cpu
)
?
;
let
model
=
if
args
.quantized
{
let
filename
=
&
filenames
[
0
];
let
vb
=
candle_transformers
::
quantized_var_builder
::
VarBuilder
::
from_gguf
(
filename
,
&
device
)
?
;
match
args
.which
{
Which
::
World1b5
|
Which
::
World3b
|
Which
::
Eagle7b
=>
Model
::
Q5
(
Q5
::
new
(
&
config
,
vb
)
?
),
Which
::
World6_1b6
=>
Model
::
Q6
(
Q6
::
new
(
&
config
,
vb
)
?
),
}
}
else
{
let
vb
=
unsafe
{
VarBuilder
::
from_mmaped_safetensors
(
&
filenames
,
DType
::
F32
,
&
device
)
?
};
match
args
.which
{
Which
::
World1b5
|
Which
::
World3b
|
Which
::
Eagle7b
=>
Model
::
M5
(
M5
::
new
(
&
config
,
vb
)
?
),
Which
::
World6_1b6
=>
Model
::
M6
(
M6
::
new
(
&
config
,
vb
)
?
),
}
};
println!
(
"loaded the model in {:?}"
,
start
.elapsed
());
let
mut
pipeline
=
TextGeneration
::
new
(
model
,
config
,
tokenizer
,
args
.seed
,
args
.temperature
,
args
.top_p
,
args
.repeat_penalty
,
args
.repeat_last_n
,
&
device
,
);
pipeline
.run
(
&
args
.prompt
,
args
.sample_len
)
?
;
Ok
(())
}
candle-examples/examples/segformer/README.md
0 → 100644
View file @
25d2752f
# candle-segformer
-
[
HuggingFace Segformer Model Card
][
segformer
]
-
[
`mit-b0` - An encoder only pretrained model
][
encoder
]
-
[
`segformer-b0-finetuned-ade-512-512` - A fine tuned model for segmentation
][
ade512
]
## How to run the example
If you want you can use the example images from this
[
pull request
][
pr
]
, download them and supply the path to the image as an argument to the example.
```
bash
# run the image classification task
cargo run
--example
segformer classify <path-to-image>
# run the segmentation task
cargo run
--example
segformer segment <path-to-image>
```
Example output for classification:
```
text
classification logits [3.275261e-5, 0.0008562019, 0.0008868563, 0.9977506, 0.0002465068, 0.0002241473, 2.846596e-6]
label: hamburger
```
[
pr
]:
https://github.com/huggingface/candle/pull/1617
[
segformer
]:
https://huggingface.co/docs/transformers/model_doc/segformer
[
encoder
]:
https://huggingface.co/nvidia/mit-b0
[
ade512
]:
https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512
candle-examples/examples/segformer/assets/labels.json
0 → 100644
View file @
25d2752f
[
{
"index"
:
1
,
"color"
:
"#787878"
,
"label"
:
"wall"
},
{
"index"
:
2
,
"color"
:
"#B47878"
,
"label"
:
"building;edifice"
},
{
"index"
:
3
,
"color"
:
"#06E6E6"
,
"label"
:
"sky"
},
{
"index"
:
4
,
"color"
:
"#503232"
,
"label"
:
"floor;flooring"
},
{
"index"
:
5
,
"color"
:
"#04C803"
,
"label"
:
"tree"
},
{
"index"
:
6
,
"color"
:
"#787850"
,
"label"
:
"ceiling"
},
{
"index"
:
7
,
"color"
:
"#8C8C8C"
,
"label"
:
"road;route"
},
{
"index"
:
8
,
"color"
:
"#CC05FF"
,
"label"
:
"bed"
},
{
"index"
:
9
,
"color"
:
"#E6E6E6"
,
"label"
:
"windowpane;window"
},
{
"index"
:
10
,
"color"
:
"#04FA07"
,
"label"
:
"grass"
},
{
"index"
:
11
,
"color"
:
"#E005FF"
,
"label"
:
"cabinet"
},
{
"index"
:
12
,
"color"
:
"#EBFF07"
,
"label"
:
"sidewalk;pavement"
},
{
"index"
:
13
,
"color"
:
"#96053D"
,
"label"
:
"person;individual;someone;somebody;mortal;soul"
},
{
"index"
:
14
,
"color"
:
"#787846"
,
"label"
:
"earth;ground"
},
{
"index"
:
15
,
"color"
:
"#08FF33"
,
"label"
:
"door;double;door"
},
{
"index"
:
16
,
"color"
:
"#FF0652"
,
"label"
:
"table"
},
{
"index"
:
17
,
"color"
:
"#8FFF8C"
,
"label"
:
"mountain;mount"
},
{
"index"
:
18
,
"color"
:
"#CCFF04"
,
"label"
:
"plant;flora;plant;life"
},
{
"index"
:
19
,
"color"
:
"#FF3307"
,
"label"
:
"curtain;drape;drapery;mantle;pall"
},
{
"index"
:
20
,
"color"
:
"#CC4603"
,
"label"
:
"chair"
},
{
"index"
:
21
,
"color"
:
"#0066C8"
,
"label"
:
"car;auto;automobile;machine;motorcar"
},
{
"index"
:
22
,
"color"
:
"#3DE6FA"
,
"label"
:
"water"
},
{
"index"
:
23
,
"color"
:
"#FF0633"
,
"label"
:
"painting;picture"
},
{
"index"
:
24
,
"color"
:
"#0B66FF"
,
"label"
:
"sofa;couch;lounge"
},
{
"index"
:
25
,
"color"
:
"#FF0747"
,
"label"
:
"shelf"
},
{
"index"
:
26
,
"color"
:
"#FF09E0"
,
"label"
:
"house"
},
{
"index"
:
27
,
"color"
:
"#0907E6"
,
"label"
:
"sea"
},
{
"index"
:
28
,
"color"
:
"#DCDCDC"
,
"label"
:
"mirror"
},
{
"index"
:
29
,
"color"
:
"#FF095C"
,
"label"
:
"rug;carpet;carpeting"
},
{
"index"
:
30
,
"color"
:
"#7009FF"
,
"label"
:
"field"
},
{
"index"
:
31
,
"color"
:
"#08FFD6"
,
"label"
:
"armchair"
},
{
"index"
:
32
,
"color"
:
"#07FFE0"
,
"label"
:
"seat"
},
{
"index"
:
33
,
"color"
:
"#FFB806"
,
"label"
:
"fence;fencing"
},
{
"index"
:
34
,
"color"
:
"#0AFF47"
,
"label"
:
"desk"
},
{
"index"
:
35
,
"color"
:
"#FF290A"
,
"label"
:
"rock;stone"
},
{
"index"
:
36
,
"color"
:
"#07FFFF"
,
"label"
:
"wardrobe;closet;press"
},
{
"index"
:
37
,
"color"
:
"#E0FF08"
,
"label"
:
"lamp"
},
{
"index"
:
38
,
"color"
:
"#6608FF"
,
"label"
:
"bathtub;bathing;tub;bath;tub"
},
{
"index"
:
39
,
"color"
:
"#FF3D06"
,
"label"
:
"railing;rail"
},
{
"index"
:
40
,
"color"
:
"#FFC207"
,
"label"
:
"cushion"
},
{
"index"
:
41
,
"color"
:
"#FF7A08"
,
"label"
:
"base;pedestal;stand"
},
{
"index"
:
42
,
"color"
:
"#00FF14"
,
"label"
:
"box"
},
{
"index"
:
43
,
"color"
:
"#FF0829"
,
"label"
:
"column;pillar"
},
{
"index"
:
44
,
"color"
:
"#FF0599"
,
"label"
:
"signboard;sign"
},
{
"index"
:
45
,
"color"
:
"#0633FF"
,
"label"
:
"chest;of;drawers;chest;bureau;dresser"
},
{
"index"
:
46
,
"color"
:
"#EB0CFF"
,
"label"
:
"counter"
},
{
"index"
:
47
,
"color"
:
"#A09614"
,
"label"
:
"sand"
},
{
"index"
:
48
,
"color"
:
"#00A3FF"
,
"label"
:
"sink"
},
{
"index"
:
49
,
"color"
:
"#8C8C8C"
,
"label"
:
"skyscraper"
},
{
"index"
:
50
,
"color"
:
"#FA0A0F"
,
"label"
:
"fireplace;hearth;open;fireplace"
},
{
"index"
:
51
,
"color"
:
"#14FF00"
,
"label"
:
"refrigerator;icebox"
},
{
"index"
:
52
,
"color"
:
"#1FFF00"
,
"label"
:
"grandstand;covered;stand"
},
{
"index"
:
53
,
"color"
:
"#FF1F00"
,
"label"
:
"path"
},
{
"index"
:
54
,
"color"
:
"#FFE000"
,
"label"
:
"stairs;steps"
},
{
"index"
:
55
,
"color"
:
"#99FF00"
,
"label"
:
"runway"
},
{
"index"
:
56
,
"color"
:
"#0000FF"
,
"label"
:
"case;display;case;showcase;vitrine"
},
{
"index"
:
57
,
"color"
:
"#FF4700"
,
"label"
:
"pool;table;billiard;table;snooker;table"
},
{
"index"
:
58
,
"color"
:
"#00EBFF"
,
"label"
:
"pillow"
},
{
"index"
:
59
,
"color"
:
"#00ADFF"
,
"label"
:
"screen;door;screen"
},
{
"index"
:
60
,
"color"
:
"#1F00FF"
,
"label"
:
"stairway;staircase"
},
{
"index"
:
61
,
"color"
:
"#0BC8C8"
,
"label"
:
"river"
},
{
"index"
:
62
,
"color"
:
"#FF5200"
,
"label"
:
"bridge;span"
},
{
"index"
:
63
,
"color"
:
"#00FFF5"
,
"label"
:
"bookcase"
},
{
"index"
:
64
,
"color"
:
"#003DFF"
,
"label"
:
"blind;screen"
},
{
"index"
:
65
,
"color"
:
"#00FF70"
,
"label"
:
"coffee;table;cocktail;table"
},
{
"index"
:
66
,
"color"
:
"#00FF85"
,
"label"
:
"toilet;can;commode;crapper;pot;potty;stool;throne"
},
{
"index"
:
67
,
"color"
:
"#FF0000"
,
"label"
:
"flower"
},
{
"index"
:
68
,
"color"
:
"#FFA300"
,
"label"
:
"book"
},
{
"index"
:
69
,
"color"
:
"#FF6600"
,
"label"
:
"hill"
},
{
"index"
:
70
,
"color"
:
"#C2FF00"
,
"label"
:
"bench"
},
{
"index"
:
71
,
"color"
:
"#008FFF"
,
"label"
:
"countertop"
},
{
"index"
:
72
,
"color"
:
"#33FF00"
,
"label"
:
"stove;kitchen;stove;range;kitchen;range;cooking;stove"
},
{
"index"
:
73
,
"color"
:
"#0052FF"
,
"label"
:
"palm;palm;tree"
},
{
"index"
:
74
,
"color"
:
"#00FF29"
,
"label"
:
"kitchen;island"
},
{
"index"
:
75
,
"color"
:
"#00FFAD"
,
"label"
:
"computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system"
},
{
"index"
:
76
,
"color"
:
"#0A00FF"
,
"label"
:
"swivel;chair"
},
{
"index"
:
77
,
"color"
:
"#ADFF00"
,
"label"
:
"boat"
},
{
"index"
:
78
,
"color"
:
"#00FF99"
,
"label"
:
"bar"
},
{
"index"
:
79
,
"color"
:
"#FF5C00"
,
"label"
:
"arcade;machine"
},
{
"index"
:
80
,
"color"
:
"#FF00FF"
,
"label"
:
"hovel;hut;hutch;shack;shanty"
},
{
"index"
:
81
,
"color"
:
"#FF00F5"
,
"label"
:
"bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle"
},
{
"index"
:
82
,
"color"
:
"#FF0066"
,
"label"
:
"towel"
},
{
"index"
:
83
,
"color"
:
"#FFAD00"
,
"label"
:
"light;light;source"
},
{
"index"
:
84
,
"color"
:
"#FF0014"
,
"label"
:
"truck;motortruck"
},
{
"index"
:
85
,
"color"
:
"#FFB8B8"
,
"label"
:
"tower"
},
{
"index"
:
86
,
"color"
:
"#001FFF"
,
"label"
:
"chandelier;pendant;pendent"
},
{
"index"
:
87
,
"color"
:
"#00FF3D"
,
"label"
:
"awning;sunshade;sunblind"
},
{
"index"
:
88
,
"color"
:
"#0047FF"
,
"label"
:
"streetlight;street;lamp"
},
{
"index"
:
89
,
"color"
:
"#FF00CC"
,
"label"
:
"booth;cubicle;stall;kiosk"
},
{
"index"
:
90
,
"color"
:
"#00FFC2"
,
"label"
:
"television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box"
},
{
"index"
:
91
,
"color"
:
"#00FF52"
,
"label"
:
"airplane;aeroplane;plane"
},
{
"index"
:
92
,
"color"
:
"#000AFF"
,
"label"
:
"dirt;track"
},
{
"index"
:
93
,
"color"
:
"#0070FF"
,
"label"
:
"apparel;wearing;apparel;dress;clothes"
},
{
"index"
:
94
,
"color"
:
"#3300FF"
,
"label"
:
"pole"
},
{
"index"
:
95
,
"color"
:
"#00C2FF"
,
"label"
:
"land;ground;soil"
},
{
"index"
:
96
,
"color"
:
"#007AFF"
,
"label"
:
"bannister;banister;balustrade;balusters;handrail"
},
{
"index"
:
97
,
"color"
:
"#00FFA3"
,
"label"
:
"escalator;moving;staircase;moving;stairway"
},
{
"index"
:
98
,
"color"
:
"#FF9900"
,
"label"
:
"ottoman;pouf;pouffe;puff;hassock"
},
{
"index"
:
99
,
"color"
:
"#00FF0A"
,
"label"
:
"bottle"
},
{
"index"
:
100
,
"color"
:
"#FF7000"
,
"label"
:
"buffet;counter;sideboard"
},
{
"index"
:
101
,
"color"
:
"#8FFF00"
,
"label"
:
"poster;posting;placard;notice;bill;card"
},
{
"index"
:
102
,
"color"
:
"#5200FF"
,
"label"
:
"stage"
},
{
"index"
:
103
,
"color"
:
"#A3FF00"
,
"label"
:
"van"
},
{
"index"
:
104
,
"color"
:
"#FFEB00"
,
"label"
:
"ship"
},
{
"index"
:
105
,
"color"
:
"#08B8AA"
,
"label"
:
"fountain"
},
{
"index"
:
106
,
"color"
:
"#8500FF"
,
"label"
:
"conveyer;belt;conveyor;belt;conveyer;conveyor;transporter"
},
{
"index"
:
107
,
"color"
:
"#00FF5C"
,
"label"
:
"canopy"
},
{
"index"
:
108
,
"color"
:
"#B800FF"
,
"label"
:
"washer;automatic;washer;washing;machine"
},
{
"index"
:
109
,
"color"
:
"#FF001F"
,
"label"
:
"plaything;toy"
},
{
"index"
:
110
,
"color"
:
"#00B8FF"
,
"label"
:
"swimming;pool;swimming;bath;natatorium"
},
{
"index"
:
111
,
"color"
:
"#00D6FF"
,
"label"
:
"stool"
},
{
"index"
:
112
,
"color"
:
"#FF0070"
,
"label"
:
"barrel;cask"
},
{
"index"
:
113
,
"color"
:
"#5CFF00"
,
"label"
:
"basket;handbasket"
},
{
"index"
:
114
,
"color"
:
"#00E0FF"
,
"label"
:
"waterfall;falls"
},
{
"index"
:
115
,
"color"
:
"#70E0FF"
,
"label"
:
"tent;collapsible;shelter"
},
{
"index"
:
116
,
"color"
:
"#46B8A0"
,
"label"
:
"bag"
},
{
"index"
:
117
,
"color"
:
"#A300FF"
,
"label"
:
"minibike;motorbike"
},
{
"index"
:
118
,
"color"
:
"#9900FF"
,
"label"
:
"cradle"
},
{
"index"
:
119
,
"color"
:
"#47FF00"
,
"label"
:
"oven"
},
{
"index"
:
120
,
"color"
:
"#FF00A3"
,
"label"
:
"ball"
},
{
"index"
:
121
,
"color"
:
"#FFCC00"
,
"label"
:
"food;solid;food"
},
{
"index"
:
122
,
"color"
:
"#FF008F"
,
"label"
:
"step;stair"
},
{
"index"
:
123
,
"color"
:
"#00FFEB"
,
"label"
:
"tank;storage;tank"
},
{
"index"
:
124
,
"color"
:
"#85FF00"
,
"label"
:
"trade;name;brand;name;brand;marque"
},
{
"index"
:
125
,
"color"
:
"#FF00EB"
,
"label"
:
"microwave;microwave;oven"
},
{
"index"
:
126
,
"color"
:
"#F500FF"
,
"label"
:
"pot;flowerpot"
},
{
"index"
:
127
,
"color"
:
"#FF007A"
,
"label"
:
"animal;animate;being;beast;brute;creature;fauna"
},
{
"index"
:
128
,
"color"
:
"#FFF500"
,
"label"
:
"bicycle;bike;wheel;cycle"
},
{
"index"
:
129
,
"color"
:
"#0ABED4"
,
"label"
:
"lake"
},
{
"index"
:
130
,
"color"
:
"#D6FF00"
,
"label"
:
"dishwasher;dish;washer;dishwashing;machine"
},
{
"index"
:
131
,
"color"
:
"#00CCFF"
,
"label"
:
"screen;silver;screen;projection;screen"
},
{
"index"
:
132
,
"color"
:
"#1400FF"
,
"label"
:
"blanket;cover"
},
{
"index"
:
133
,
"color"
:
"#FFFF00"
,
"label"
:
"sculpture"
},
{
"index"
:
134
,
"color"
:
"#0099FF"
,
"label"
:
"hood;exhaust;hood"
},
{
"index"
:
135
,
"color"
:
"#0029FF"
,
"label"
:
"sconce"
},
{
"index"
:
136
,
"color"
:
"#00FFCC"
,
"label"
:
"vase"
},
{
"index"
:
137
,
"color"
:
"#2900FF"
,
"label"
:
"traffic;light;traffic;signal;stoplight"
},
{
"index"
:
138
,
"color"
:
"#29FF00"
,
"label"
:
"tray"
},
{
"index"
:
139
,
"color"
:
"#AD00FF"
,
"label"
:
"ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin"
},
{
"index"
:
140
,
"color"
:
"#00F5FF"
,
"label"
:
"fan"
},
{
"index"
:
141
,
"color"
:
"#4700FF"
,
"label"
:
"pier;wharf;wharfage;dock"
},
{
"index"
:
142
,
"color"
:
"#7A00FF"
,
"label"
:
"crt;screen"
},
{
"index"
:
143
,
"color"
:
"#00FFB8"
,
"label"
:
"plate"
},
{
"index"
:
144
,
"color"
:
"#005CFF"
,
"label"
:
"monitor;monitoring;device"
},
{
"index"
:
145
,
"color"
:
"#B8FF00"
,
"label"
:
"bulletin;board;notice;board"
},
{
"index"
:
146
,
"color"
:
"#0085FF"
,
"label"
:
"shower"
},
{
"index"
:
147
,
"color"
:
"#FFD600"
,
"label"
:
"radiator"
},
{
"index"
:
148
,
"color"
:
"#19C2C2"
,
"label"
:
"glass;drinking;glass"
},
{
"index"
:
149
,
"color"
:
"#66FF00"
,
"label"
:
"clock"
},
{
"index"
:
150
,
"color"
:
"#5C00FF"
,
"label"
:
"flag"
}
]
Prev
1
…
8
9
10
11
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment