Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
pymc3
Commits
2204de4f
Commit
2204de4f
authored
Aug 15, 2024
by
wangsen
Browse files
readme.md
parents
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
84 additions
and
0 deletions
+84
-0
README.md
README.md
+60
-0
test_gpu.py
test_gpu.py
+5
-0
test_sample_time.py
test_sample_time.py
+19
-0
No files found.
README.md
0 → 100644
View file @
2204de4f
# 下载jax和jaxlib
```
wget https://cancon.hpccube.com:65024/directlink/4/jax/DAS1.1/jaxlib-0.4.23+das1.1.git387bd43.abi1.dtk2404-cp39-cp39-manylinux_2_31_x86_64.whl
wget https://cancon.hpccube.com:65024/directlink/4/jax/DAS1.1/jax-0.4.23+das1.1.git387bd43.abi1.dtk2404-py3-none-any.whl
```
# conda 环境
```
conda create -n pymc3 python=3.9
conda activate pymc3
pip install jax-0.4.23+das1.1.git387bd43.abi1.dtk2404-py3-none-any.whl
pip install jaxlib-0.4.23+das1.1.git387bd43.abi1.dtk2404-cp39-cp39-manylinux_2_31_x86_64.whl
pip install pymc==5.9.1
pip install numpyro==0.14.0
pip install seaborn==0.13.2
pip install scipy==1.12.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
```
# 测试jax是否在gpu上可行:
```
import jax
import pymc3 as pm
jax.default_backend()
jax.devices()
```
# 测试采样时间,GPU负载等情况
```
import jax
import pymc as pm
import numpy as np
import pytensor as pt
pt.config.floatX = "float32"
np.random.seed(123)
n =10000
X = np.random.randn(n)
Y =3* X + np.random.randn(n)
# 定义PyMC3模型
with pm.Model() as model:
alpha = pm.Normal('alpha', mu=0, sigma=10)
beta = pm.Normal('beta', mu=0, sigma=10)
sigma = pm.HalfNormal('sigma', sigma=1)
mu = alpha + beta * X
Y_obs = pm.Normal('Y_obs', mu=mu, sigma=sigma, observed=Y)
trace = pm.sample(1000, nuts_sampler="numpyro",return_inferencedata=False)
print(pm.summary(trace))
```
\ No newline at end of file
test_gpu.py
0 → 100644
View file @
2204de4f
import
jax
import
pymc
as
pm
print
(
jax
.
default_backend
())
print
(
jax
.
devices
())
\ No newline at end of file
test_sample_time.py
0 → 100644
View file @
2204de4f
import
jax
import
pymc
as
pm
import
numpy
as
np
import
pytensor
as
pt
pt
.
config
.
floatX
=
"float32"
np
.
random
.
seed
(
123
)
n
=
10000
X
=
np
.
random
.
randn
(
n
)
Y
=
3
*
X
+
np
.
random
.
randn
(
n
)
# 定义PyMC3模型
with
pm
.
Model
()
as
model
:
alpha
=
pm
.
Normal
(
'alpha'
,
mu
=
0
,
sigma
=
10
)
beta
=
pm
.
Normal
(
'beta'
,
mu
=
0
,
sigma
=
10
)
sigma
=
pm
.
HalfNormal
(
'sigma'
,
sigma
=
1
)
mu
=
alpha
+
beta
*
X
Y_obs
=
pm
.
Normal
(
'Y_obs'
,
mu
=
mu
,
sigma
=
sigma
,
observed
=
Y
)
trace
=
pm
.
sample
(
1000
,
nuts_sampler
=
"numpyro"
,
return_inferencedata
=
False
)
print
(
pm
.
summary
(
trace
))
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment